diff options
author | Michael Case <mikecase@google.com> | 2018-06-26 10:56:16 -0700 |
---|---|---|
committer | Michael Case <mikecase@google.com> | 2018-06-26 10:56:16 -0700 |
commit | 7f1056bcc9af72f6ed68939423362e390ce6ad8b (patch) | |
tree | cc434c644a508ac442f79d4463f72c929a017444 /tensorflow | |
parent | 343b373e3386f11a16a5216574492ca56bfd7050 (diff) | |
parent | f2813bf6e4f7f415f012307a03fd5b9fb5822d28 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow')
208 files changed, 4284 insertions, 1935 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9d5f98d4d6..a8ad8e4b94 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2414,7 +2414,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; - g->name_map[n->name()] = n; + // We have a convoluted scheme here: Using the C++ graph construction API + // to add potentially many nodes to the graph without running the checks + // (such as uniqueness of the names of nodes) we run with other functions + // that add a node to the graph (like TF_FinishOperation). + if (!g->name_map.insert(std::make_pair(n->name(), n)).second) { + status->status = tensorflow::errors::Internal( + "BUG: The API allowed construction of a graph with duplicate node " + "names (", + n->name(), + "). This is a bug. Please file an issue at " + "https://github.com/tensorflow/tensorflow/issues."); + } } } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 577f10c5e6..bc04b53fbb 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1160,7 +1160,7 @@ TEST(CAPI, GetOpDef) { } void StringVectorToArrays(const std::vector<string>& v, - std::unique_ptr<const void* []>* ptrs, + std::unique_ptr<const void*[]>* ptrs, std::unique_ptr<size_t[]>* lens) { ptrs->reset(new const void*[v.size()]); lens->reset(new size_t[v.size()]); @@ -1196,7 +1196,7 @@ class CApiColocationTest : public ::testing::Test { void SetViaStringList(TF_OperationDescription* desc, const std::vector<string>& list) { - std::unique_ptr<const void* []> list_ptrs; + std::unique_ptr<const void*[]> list_ptrs; std::unique_ptr<size_t[]> list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(), @@ -1700,6 +1700,61 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { 
TestGradientsError(false); } +void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { + ASSERT_TRUE(t != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); + ASSERT_EQ(0, TF_NumDims(t)); + ASSERT_EQ(4, TF_TensorByteSize(t)); + float* p = static_cast<float*>(TF_TensorData(t)); + *f = *p; +} + +TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) { + const float X = 3.0f, Y = 7.0f; + TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT); + TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT); + TF_Operation* xy = Mul(x, y, graph_, s_, "xy"); + TF_Output dxy_dx, dxy_dy; + + TF_Output outputs[1] = {{xy, 0}}; + TF_Output inputs[1] = {{x, 0}}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + inputs[0] = {y, 0}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph_, opts, s_); + TF_DeleteSessionOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Output feeds[] = {{x, 0}, {y, 0}}; + TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)}; + TF_Output fetches[] = {dxy_dx, dxy_dy}; + TF_Tensor* fetchValues[] = {nullptr, nullptr}; + + TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches, + fetchValues, 2, nullptr /* target_opers */, 0, + nullptr /* run_metadata */, s_); + TF_DeleteTensor(feedValues[0]); + TF_DeleteTensor(feedValues[1]); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_DeleteSession(sess, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f; + ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue); + EXPECT_EQ(Y, dxy_dxValue); + + ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue); + EXPECT_EQ(X, dxy_dyValue); + + TF_DeleteTensor(fetchValues[0]); + TF_DeleteTensor(fetchValues[1]); +} + // 
REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other @@ -1784,7 +1839,7 @@ TEST_F(CApiAttributesTest, String) { TEST_F(CApiAttributesTest, StringList) { std::vector<string> list = {"bugs", "bunny", "duck"}; - std::unique_ptr<const void* []> list_ptrs; + std::unique_ptr<const void*[]> list_ptrs; std::unique_ptr<size_t[]> list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); int list_total_size = 0; @@ -1800,7 +1855,7 @@ TEST_F(CApiAttributesTest, StringList) { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size); - std::unique_ptr<void* []> values(new void*[list.size()]); + std::unique_ptr<void*[]> values(new void*[list.size()]); std::unique_ptr<size_t[]> lens(new size_t[list.size()]); std::unique_ptr<char[]> storage(new char[list_total_size]); TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(), @@ -2025,7 +2080,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) { tensorflow::PartialTensorShape(pts2).AsProto(&proto); proto.SerializeToString(&bytes2); - std::unique_ptr<const void* []> list_ptrs; + std::unique_ptr<const void*[]> list_ptrs; std::unique_ptr<size_t[]> list_lens; const std::vector<string> list = {bytes1, bytes2}; StringVectorToArrays(list, &list_ptrs, &list_lens); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index f3b28c1708..24eb6c069b 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -216,6 +216,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); } +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true); + return op; +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, 
TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index c16aba666e..38313d647c 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -80,6 +80,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "min"); +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "mul"); + // If `op_device` is non-empty, set the created op on that device. TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, const string& op_device, TF_Status* s, diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 62a889181e..8c886f3171 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) { return *this; } +namespace { +const char kScopeSeparator[] = "/"; +const char kSuffixSeparator[] = "_"; +} // namespace + Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), @@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix, return prefix; } auto entry = name_map_->find(prefix); - string unique_name = prefix; if (entry == name_map_->end()) { name_map_->insert({prefix, 0}); - } else { - unique_name = strings::StrCat(unique_name, "_", ++entry->second); + return prefix; } + string unique_name; + do { + unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second); + } while (name_map_->find(unique_name) != name_map_->end()); + name_map_->insert({unique_name, 0}); return unique_name; } string Scope::Impl::GetNameForOp(const string& default_name) const { const string unique_name = GetUniqueName(default_name, true /* 
check_single_use */); - const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return strings::StrCat(name_, sep, unique_name); } @@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const { } const string unique_name = impl()->GetUniqueName(child_scope_name, false /* check_single_use */); - const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return Scope(new Impl(*this, Impl::Tags::ScopeName(), strings::StrCat(impl()->name_, sep, unique_name), false /* copy_names */)); @@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( if (!impl()->single_use_scope()) { Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name : impl()->op_name_); - const string child_op_sep = impl()->name_.empty() ? "" : "_"; + const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator; const string child_name = strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_); return {child, @@ -435,7 +445,13 @@ class InternalScope { static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; for (const Node* node : graph->nodes()) { - (*name_map)[node->name()] = 0; + const string& name = node->name(); + (*name_map)[name] = 0; + // Add all name prefixes ('/' separated). + size_t idx = -1; + while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) { + (*name_map)[name.substr(0, idx)] = 0; + } } // We provide null destructors for these shared ptrs (except for name_map) // since the caller owns them and doesn't want the scope to destroy them. 
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 8efcfed20d..58adaef2e9 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -34,8 +34,7 @@ class Scope::Impl { // name that has not been used so far in a scope will get no suffix. Later // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. + // WithControlDependencies() would share the same NameMap with the parent. typedef std::unordered_map<string, int> NameMap; Impl(const std::shared_ptr<Graph>& graph, diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc index 9eca9d3fac..b40b345eb8 100644 --- a/tensorflow/cc/framework/scope_test.cc +++ b/tensorflow/cc/framework/scope_test.cc @@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) { EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul"); } +TEST(ScopeTest, OpAndScopeNameCollision) { + Scope root = Scope::NewRootScope(); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1"); +} + TEST(ScopeTest, HierarchicalNames) { Scope root = Scope::NewRootScope(); Scope child = root.NewSubScope("child"); diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index afb5fa4bb4..b2360dd009 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -27,6 +27,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import 
dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest @@ -97,8 +98,11 @@ class FFTTest(XLATestCase): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) out = signal.stft(ph, ws, hs) + grad = gradients_impl.gradients(out, ph, + grad_ys=array_ops.ones_like(out)) - value = sess.run(out, {ph: data}) + # For gradients, we simply verify that they compile & execute. + value, _ = sess.run([out, grad], {ph: data}) self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) def testFFT(self): diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 4a9c0e7471..772c20fd42 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -21,26 +21,40 @@ from __future__ import print_function import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class SegmentReductionOpsTest(XLATestCase): +class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" - def UnsortedSegmentSum(self, data, indices, num_segments): + def _segmentReduction(self, op, data, indices, num_segments): with self.test_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) else: i = array_ops.placeholder(indices.dtype, shape=indices.shape) - return sess.run( - math_ops.unsorted_segment_sum(d, i, num_segments), - {d: data, - i: indices}) + return sess.run(op(d, i, num_segments), {d: data, 
i: indices}) + + def _unsortedSegmentSum(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, + num_segments) + + def _unsortedSegmentProd(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, + num_segments) + + def _unsortedSegmentMin(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_min, data, indices, + num_segments) + + def _unsortedSegmentMax(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, + num_segments) def testUnsortedSegmentSum0DIndices1DData(self): for dtype in self.numeric_types: @@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase): [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 0]], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) def testUnsortedSegmentSum1DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( np.array([1, 3, 2, 9], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) @@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase): for dtype in self.numeric_types: self.assertAllClose( np.array([6, 3, 0, 6], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) @@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) num_segments = 10 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], @@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase): 
dtype=dtype) indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) num_segments = 4 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], @@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase): def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], - [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, - 302], [0, 0, 0]], + [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + 100, 102, 104 + ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], @@ -138,10 +152,40 @@ class 
SegmentReductionOpsTest(XLATestCase): data = np.ones((4, 8, 7), dtype=dtype) indices = np.ones((3, 2), dtype=np.int32) num_segments = 4 - self.assertRaises(ValueError, - functools.partial(self.UnsortedSegmentSum, data, - indices, num_segments)) + self.assertRaises( + ValueError, + functools.partial(self._segmentReduction, + math_ops.unsorted_segment_sum, data, indices, + num_segments)) + + def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): + """Tests for min, max, and prod ops. + + These share most of their implementation with sum, so we only test basic + functionality. + """ + for dtype in self.numeric_types: + self.assertAllClose( + np.array([8, 3, 1, 0], dtype=dtype), + self._unsortedSegmentProd( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + + for dtype in self.int_types | self.float_types: + minval = dtypes.as_dtype(dtype).min + maxval = dtypes.as_dtype(dtype).max + + self.assertAllClose( + np.array([2, 3, maxval, 0], dtype=dtype), + self._unsortedSegmentMin( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + self.assertAllClose( + np.array([4, 3, minval, 6], dtype=dtype), + self._unsortedSegmentMax( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) -if __name__ == '__main__': +if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 140dad61d9..6cc95149a1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -166,6 +166,27 @@ StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } +// Check that the graph has no cycle containing the given node. 
+Status CheckNoCycleContains(const Node* node, const int num_nodes) { + std::vector<const Node*> ready; + ready.push_back(node); + std::vector<bool> visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", + node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); @@ -1407,6 +1428,10 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf( TF_RETURN_IF_ERROR( AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(if_node, graph_->num_node_ids()), + "ConvertToXlaIf failed."); return if_node; } @@ -1506,6 +1531,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, worklist.push_back(frame->parent); } } + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. 
+ for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(node, graph->num_node_ids()), + "FunctionalizeLoop failed."); + } + } // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 14977a908a..aae2f8ee5a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -1012,5 +1013,60 @@ TEST(FunctionalizeControlFlow, Complex) { } } +TEST(FunctionalizeControlFlow, Cycle) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + // ----------------------------------------------------- + // | | + // | v + // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 + // | ^ | + // | | v + // --------> one -------------------------> add_2 ---> merge_2 + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto two = + ops::Const<int32>(scope.WithOpName("cond/two") + .WithControlDependencies(switch_1.output_true), + 2); + auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), + switch_1.output_true, two); 
+ auto one = + ops::Const<int32>(scope.WithOpName("cond/one") + .WithControlDependencies(switch_1.output_false), + 1); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_1.output_false, one); + + auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list<Input>{add, mul}); + auto identity = + ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); + auto switch_2 = + ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); + auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), + switch_2.output_false, one); + auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), + switch_2.output_true, two); + auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), + std::initializer_list<Input>{add_2, mul_2}); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + // No cycle before functionalize control flow. + TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // switch_1 and switch_2 have the same switch depth. They are replaced by a + // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: + // less -> XlaIf <--> identity. + Status status = FunctionalizeControlFlow(graph.get(), &library); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 212f6f3966..4a6622ed73 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -87,6 +88,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, } } // namespace Status GraphCompiler::Compile() { + // Check that the graph has no illegal cycles. + TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_)); // Maintain a mapping from node id to node outputs. using NodeOutputs = std::vector<TensorValue>; std::vector<NodeOutputs> output_registry(graph_->num_node_ids()); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index be83834e86..3bab4ae917 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -210,9 +210,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaOp min_positive = XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min()); auto uniform = b->RngUniform(min_positive, one, xla_shape); - auto truncated_normal_or_status = TruncatedNormal(dtype, uniform, b); - OP_REQUIRES_OK(ctx, truncated_normal_or_status.status()); - ctx->SetOutput(0, truncated_normal_or_status.ValueOrDie()); + ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 664078ca16..ff14483347 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -22,12 +22,19 @@ limitations under the License. 
namespace tensorflow { namespace { -class UnsortedSegmentSum : public XlaOpKernel { +class UnsortedSegmentReduce : public XlaOpKernel { public: - explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } + // The initial value to initialize elements of the output to. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + + // A function to combine two scalars with the same index (e.g., sum). + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) = 0; + void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) // Compute a tensor such that: @@ -50,27 +57,29 @@ class UnsortedSegmentSum : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), - errors::InvalidArgument( - "UnsortedSegmentSum requires that indices' rank be" - " less than or equal to data's rank.")); + errors::InvalidArgument(type_string(), + " requires that indices' rank be" + " less than or equal to data's rank.")); // Validate that indices.shape is a prefix of data.shape. for (int d = 0; d < indices_shape.dims(); ++d) { - OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), - errors::InvalidArgument( - "UnsortedSegmentSum requires indices shape to be prefix" - " of data_shape, but dimension ", - d, " differs ", data_shape.dim_size(d), " vs. ", - indices_shape.dim_size(d))); + OP_REQUIRES( + ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument(type_string(), + " requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), + " vs. 
", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), - buffer_shape.dim_sizes()); + auto buffer = + builder->Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); - auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { - return builder->Add(a, b); + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) { + return Combine(a, b, builder); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, @@ -79,13 +88,81 @@ class UnsortedSegmentSum : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie()); } - private: + protected: DataType dtype_; }; +class UnsortedSegmentSum : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::Zero(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return builder->Add(a, b); + }; +}; + REGISTER_XLA_OP( Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), UnsortedSegmentSum); +class UnsortedSegmentProd : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentProd(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::One(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return builder->Mul(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"), + UnsortedSegmentProd); + +class UnsortedSegmentMin : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMin(OpKernelConstruction* ctx) + : 
UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MaxFiniteValue(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return builder->Min(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"), + UnsortedSegmentMin); + +class UnsortedSegmentMax : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMax(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MinFiniteValue(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return builder->Max(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"), + UnsortedSegmentMax); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 0367501433..43ab4642e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -207,10 +207,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto erfinv_or_status = ErfInv(uniform); - OP_REQUIRES_OK(ctx, erfinv_or_status.status()); auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)), - erfinv_or_status.ValueOrDie()); + ErfInv(uniform)); ctx->SetOutput(0, normal); } @@ -245,9 +243,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { auto uniform = RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0); - auto truncated_normal_or_status = TruncatedNormal(dtype, uniform, b); - OP_REQUIRES_OK(ctx, 
truncated_normal_or_status.status()); - ctx->SetOutput(0, truncated_normal_or_status.ValueOrDie()); + ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); } private: diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 4a2516244a..e4f195901e 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" namespace tensorflow { -xla::StatusOr<xla::XlaOp> TruncatedNormal(const DataType dtype, - const xla::XlaOp& uniform, - xla::XlaBuilder* builder) { + +xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) { + xla::XlaBuilder* builder = uniform.builder(); auto normal_cdf = [](double x) { return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; }; @@ -51,7 +51,7 @@ xla::StatusOr<xla::XlaOp> TruncatedNormal(const DataType dtype, // probit(p) = sqrt(2) * erfinv(2*p-1) auto p = builder->Add(alpha_normal_cdf, builder->Mul(z, uniform)); auto erfinv_input = builder->Sub(builder->Mul(p, two), one); - TF_ASSIGN_OR_RETURN(auto erfinv_or_status, ErfInv(erfinv_input)); - return builder->Mul(sqrt_2, erfinv_or_status); + return builder->Mul(sqrt_2, ErfInv(erfinv_input)); } + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h index 18c873dba5..39cbcf9c5e 100644 --- a/tensorflow/compiler/tf2xla/lib/random.h +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -21,15 +21,15 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" namespace tensorflow { + // Builds an array filled with values sampled from a truncated normal // distribution such that no values are greater than two or less than negative // two. // // The "uniform" parameter must be an array of random numbers distributed in // (0,1). 
-xla::StatusOr<xla::XlaOp> TruncatedNormal(DataType dtype, - const xla::XlaOp& uniform, - xla::XlaBuilder* builder); +xla::XlaOp TruncatedNormal(DataType dtype, xla::XlaOp uniform); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 265b39402c..5f408f2ed0 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -86,10 +86,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0<int32>(0)}, {1, 4})); + TF_ASSERT_OK(DynamicSliceInMinorDims( + &builder, a, {index, builder.ConstantR0<int32>(0)}, {1, 4}) + .status()); ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -132,9 +131,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { auto l_index, DynamicSliceInMinorDims(&builder, a, {index, builder.ConstantR0<int32>(0)}, {1, n})); - TF_ASSERT_OK_AND_ASSIGN( - auto dot, BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true)); + TF_ASSERT_OK(BatchDot(&builder, l_index, row, + /*transpose_x=*/false, /*transpose_y=*/true) + .status()); ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9c8e56a17e..e646ffe39f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -384,13 +384,14 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - 
bool modified = resource->value() != resource->initial_value(); + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value() != grad.second->initial_value() || - arg.tensor_array_gradients.count(grad.first) == 0; + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 93cd340485..31115eea60 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -97,12 +97,48 @@ xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { return b->ConstantLiteral(xla::Literal::MinValue(type)); } +xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) { + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); + switch (type) { + case xla::F16: + return b->ConstantR0<Eigen::half>( + Eigen::NumTraits<Eigen::half>::lowest()); + case xla::BF16: + return b->ConstantR0<bfloat16>(bfloat16::lowest()); + case xla::F32: + return b->ConstantR0<float>(-std::numeric_limits<float>::max()); + case xla::F64: + return b->ConstantR0<double>(-std::numeric_limits<double>::max()); + default: + return b->ConstantLiteral(xla::Literal::MinValue(type)); + } +} + xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MaxValue(type)); } +xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) { + xla::PrimitiveType 
type; + TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); + switch (type) { + case xla::F16: + return b->ConstantR0<Eigen::half>( + Eigen::NumTraits<Eigen::half>::highest()); + case xla::BF16: + return b->ConstantR0<bfloat16>(bfloat16::highest()); + case xla::F32: + return b->ConstantR0<float>(std::numeric_limits<float>::max()); + case xla::F64: + return b->ConstantR0<double>(std::numeric_limits<double>::max()); + default: + return b->ConstantLiteral(xla::Literal::MaxValue(type)); + } +} + xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); @@ -267,6 +303,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c3fdc5252e..c320016998 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -29,13 +29,21 @@ namespace tensorflow { class XlaHelpers { public: // Returns a handle representing the minimum value of a scalar - // element of data_type. + // element of data_type. -inf for floating-point types. static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the maximum value of a scalar + // Returns a handle representing the minimum finite value of a scalar // element of data_type. + static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type); + + // Returns a handle representing the maximum value of a scalar + // element of data_type. inf for floating point types. 
static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); + // Returns a handle representing the maximum finite value of a scalar + // element of data_type. + static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type); + // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index ee6da6a67a..46785bc1f0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -240,6 +240,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { // a) the types supported by the backend, and // b) the types allowed by the OpDef, and // c) the type constraints. + bool unsatisfiable_type_constraint = false; for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); @@ -276,7 +277,14 @@ void XlaOpRegistry::RegisterCompilationKernels() { if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); } + // Don't build KernelDefs that have unsatisfiable type constraints. + if (allowed_values->type().empty()) { + unsatisfiable_type_constraint = true; + break; + } } + if (unsatisfiable_type_constraint) continue; + if (backend.second.op_filter != nullptr && !backend.second.op_filter(kdef.get())) { continue; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc index 266cbc4395..7b3b15b1af 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -82,5 +82,38 @@ TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { } } +// A dummy generic OpKernel for all backends. 
+class DummyInfeasibleTypeConstraintOp : public XlaOpKernel { + public: + explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + LOG(FATAL) << "unreachable"; + } +}; + +REGISTER_OP("DummyInfeasibleTypeConstraintOp") + .Attr("T: {float, string}") + .Input("input: T") + .Output("output: T") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); +REGISTER_XLA_OP( + Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING), + DummyInfeasibleTypeConstraintOp); + +TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + // The operator should not be registered. + EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp"); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4525197146..95bd725850 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -175,6 +175,7 @@ cc_library( hdrs = [ "iterator_util.h", "map_util.h", + "overflow_util.h", "ptr_util.h", "util.h", ], @@ -250,7 +251,7 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 8e875bf352..0d7758eef9 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -111,14 +111,17 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { }); } -StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) { - auto f = builder->ConstantR0<bool>(false); - 
XlaComputation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, - builder->GetShape(predicates)); - std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); +XlaOp Any(XlaOp predicates) { + XlaBuilder* builder = predicates.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + auto f = builder->ConstantR0<bool>(false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return builder->Reduce(predicates, f, logical_or, all_dimensions); + }); } namespace { @@ -164,7 +167,7 @@ std::array<float, 6> kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(const XlaOp& x, +XlaOp EvaluatePolynomial(XlaOp x, tensorflow::gtl::ArraySlice<float> coefficients, PrimitiveType data_type) { XlaBuilder* b = x.builder(); @@ -176,7 +179,7 @@ XlaOp EvaluatePolynomial(const XlaOp& x, } // Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(const XlaOp& x, PrimitiveType data_type) { +XlaOp Erfc(XlaOp x, PrimitiveType data_type) { XlaBuilder* b = x.builder(); XlaOp zero = FloatLiteral(b, data_type, 0.0); XlaOp two = FloatLiteral(b, data_type, 2.0); @@ -197,7 +200,7 @@ XlaOp Erfc(const XlaOp& x, PrimitiveType data_type) { } // Compute a polynomial approximation of the error function. 
-XlaOp Erf(const XlaOp& x, PrimitiveType data_type) { +XlaOp Erf(XlaOp x, PrimitiveType data_type) { XlaBuilder* b = x.builder(); XlaOp z = b->Mul(x, x); XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type); @@ -217,38 +220,40 @@ XlaOp Erf(const XlaOp& x, PrimitiveType data_type) { // p = sum_{i=1}^n gq[i]*w^i // } // return p*x -StatusOr<XlaOp> ErfInv(const XlaOp& x) { +XlaOp ErfInv(XlaOp x) { XlaBuilder* b = x.builder(); - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array<float, 9> w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array<float, 9> w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0<float>(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0<float>(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); + return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array<float, 9> w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr 
std::array<float, 9> w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0<float>(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0<float>(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 33a8254274..d0e04bbb5e 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -53,22 +53,22 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder); +XlaOp Any(XlaOp predicates); // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(const XlaOp& x, +XlaOp EvaluatePolynomial(XlaOp x, tensorflow::gtl::ArraySlice<float> coefficients, PrimitiveType data_type); // Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(const XlaOp& x, PrimitiveType data_type); +XlaOp Erfc(XlaOp x, PrimitiveType data_type); // Compute an approximation of the error function. 
-XlaOp Erf(const XlaOp& x, PrimitiveType data_type); +XlaOp Erf(XlaOp x, PrimitiveType data_type); // Compute an approximation of the inverse of the error function. -StatusOr<XlaOp> ErfInv(const XlaOp& x); +XlaOp ErfInv(XlaOp x); } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 507a2dc5f0..b0f41ac1d3 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -52,6 +52,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 256667cbe0..8515d120da 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include <string> #include <utility> +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -59,6 +60,54 @@ bool CanBeRoot(HloOpcode opcode) { } // namespace +XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { + return x.builder()->Add(x, y); +} +XlaOp operator-(const XlaOp& x, const XlaOp& y) { + return x.builder()->Sub(x, y); +} +XlaOp operator*(const XlaOp& x, const XlaOp& y) { + return x.builder()->Mul(x, y); +} +XlaOp operator/(const XlaOp& x, const XlaOp& y) { + return x.builder()->Div(x, y); +} +XlaOp operator%(const XlaOp& x, const XlaOp& y) { + return x.builder()->Rem(x, y); +} + +XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { + return x.builder()->And(x, y); +} +XlaOp operator|(const XlaOp& x, const XlaOp& y) { + return x.builder()->Or(x, y); +} +XlaOp operator^(const XlaOp& x, const XlaOp& y) { + return x.builder()->Xor(x, y); +} +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { + return x.builder()->ShiftLeft(x, y); +} + +XlaOp operator>>(const XlaOp& x, const XlaOp& y) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Argument to >> operator does not have an integral type (%s).", + ShapeUtil::HumanString(shape).c_str()); + } + if (ShapeUtil::ElementIsSigned(shape)) { + return builder->ShiftRightArithmetic(x, y); + } else { + return builder->ShiftRightLogical(x, y); + } + }); +} + StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); @@ -81,7 +130,7 @@ XlaBuilder::XlaBuilder(const string& computation_name) 
XlaBuilder::~XlaBuilder() {} -void XlaBuilder::NoteError(const Status& error) { +XlaOp XlaBuilder::ReportError(const Status& error) { CHECK(!error.ok()); if (die_immediately_on_error_) { LOG(FATAL) << "error building computation: " << error; @@ -91,19 +140,22 @@ void XlaBuilder::NoteError(const Status& error) { first_error_ = error; first_error_backtrace_.CreateCurrent(/*skip_count=*/1); } + return XlaOp(this); } -XlaOp XlaBuilder::NoteErrorOrReturn( - const std::function<StatusOr<XlaOp>()>& op_creator) { +XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) { if (!first_error_.ok()) { return XlaOp(this); } - auto op = op_creator(); if (!op.ok()) { - NoteError(op.status()); - return XlaOp(this); + return ReportError(op.status()); } - return op.ConsumeValueOrDie(); + return op.ValueOrDie(); +} + +XlaOp XlaBuilder::ReportErrorOrReturn( + const std::function<StatusOr<XlaOp>()>& op_creator) { + return ReportErrorOrReturn(op_creator()); } StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const { @@ -207,7 +259,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); if (!build_status.ok()) { - parent_builder_->NoteError( + parent_builder_->ReportError( AddStatus(build_status.status(), tensorflow::strings::StrCat("error from: ", name_))); return {}; @@ -315,7 +367,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -327,7 +379,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { XlaOp XlaBuilder::BinaryOp( HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { - 
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -383,7 +435,7 @@ XlaOp XlaBuilder::BinaryOp( XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -430,7 +482,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); *instr.mutable_literal() = literal.ToProto(); @@ -440,7 +492,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice<XlaOp> operands) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); @@ -461,7 +513,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { return InvalidArgument("parameter %lld already registered", @@ -476,7 +528,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& 
shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, @@ -510,7 +562,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices, tensorflow::gtl::ArraySlice<int64> limit_indices, tensorflow::gtl::ArraySlice<int64> strides) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -530,7 +582,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); std::vector<int64> starts(ShapeUtil::Rank(shape), 0); std::vector<int64> limits(shape.dimensions().begin(), @@ -545,7 +597,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -566,7 +618,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto 
instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -584,7 +636,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; @@ -603,7 +655,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -624,7 +676,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<int64> new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( @@ -638,7 +690,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector<int64> dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -648,7 +700,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Collapse(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { - return NoteErrorOrReturn([&]() -> 
StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. @@ -690,7 +742,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); @@ -704,7 +756,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); @@ -718,7 +770,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { } XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); if (!ShapeUtil::IsTuple(tuple_shape)) { @@ -767,7 +819,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); DotDimensionNumbers dimension_numbers; @@ -780,7 +832,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> 
StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -859,7 +911,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -905,7 +957,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -992,7 +1044,7 @@ StatusOr<Window> XlaBuilder::MakeWindow( XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice<int64> fft_length) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1009,23 +1061,69 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, } XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } - *instr.mutable_shape() = shape; + 
const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; instr.set_infeed_config(config); - return AddInstruction(std::move(instr), HloOpcode::kInfeed); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + // The sharding is set by the client according to the data tuple shape. + // However, the shape of the infeed instruction is a tuple containing the + // data and a token. For tuple sharding type, the sharding must be changed + // to accommodate the token. + XlaOp infeed; + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) { + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + OpSharding infeed_instruction_sharding = *sharding(); + // Arbitrarily assign the token to device 0. + *infeed_instruction_sharding.add_tuple_shardings() = + sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, + infeed_instruction_sharding); + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } else { + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } + + // The infeed instruction produces a tuple of the infed data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. 
+ HloInstructionProto infeed_data; + *infeed_data.mutable_shape() = shape; + infeed_data.set_tuple_index(0); + return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, + {infeed}); }); } void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { - NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1042,14 +1140,33 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, instr.set_outfeed_config(outfeed_config); - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + TF_RETURN_IF_ERROR( + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + .status()); + + // The outfeed instruction produces a token. However, existing users expect + // a nil shape (empty tuple). This should only be relevant if the outfeed is + // the root of a computation. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto tuple_instr; + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + + // The dummy tuple should have no sharding. 
+ { + XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); + TF_ASSIGN_OR_RETURN( + XlaOp empty_tuple, + AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {})); + return empty_tuple; + } }); } XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice<XlaOp> operands, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (tensorflow::str_util::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1066,7 +1183,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; *instr.mutable_shape() = shape; instr.set_channel_name(channel_name); @@ -1221,7 +1338,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1236,7 +1353,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, XlaOp XlaBuilder::Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1265,7 +1382,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() 
-> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1277,7 +1394,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1311,13 +1428,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<XlaOp> static_operands) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); } HloInstructionProto instr; - std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1329,16 +1445,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands, ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, dimensions)); + const Shape& output_shape = instr.shape(); + const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); + std::vector<XlaOp> new_operands(operands.begin(), operands.end()); + for (XlaOp& new_operand : new_operands) { + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); + const int64 rank = ShapeUtil::Rank(shape); + if (rank != output_rank) { + TF_ASSIGN_OR_RETURN(new_operand, + InDimBroadcast(output_shape, new_operand, {})); + TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand)); + } + if 
(!ShapeUtil::SameDimensions(output_shape, shape)) { + TF_ASSIGN_OR_RETURN(new_operand, + AddBroadcastSequence(output_shape, new_operand)); + } + } - return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands); }); } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, tensorflow::gtl::ArraySlice<XlaOp> parameters, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; // Check the number of parameters per RNG distribution. @@ -1376,7 +1508,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; // Infer shape. @@ -1398,7 +1530,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); @@ -1423,7 +1555,7 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); @@ -1455,7 +1587,7 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, 
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1480,7 +1612,7 @@ XlaOp XlaBuilder::Reduce( XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); std::vector<int64> all_dimnos(ShapeUtil::Rank(operand_shape)); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); @@ -1493,7 +1625,7 @@ XlaOp XlaBuilder::ReduceWindow( const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1516,7 +1648,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1540,7 +1672,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ 
-1563,7 +1695,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1588,7 +1720,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1612,7 +1744,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> replica_group_ids) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); @@ -1628,7 +1760,7 @@ XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> replica_group_ids, const tensorflow::gtl::optional<ChannelHandle>& channel_id) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { if (channel_id.has_value()) { return Unimplemented("channel_id is not supported in AllReduce"); } @@ -1655,7 +1787,7 @@ XlaOp XlaBuilder::SelectAndScatter( tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> 
StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, @@ -1672,7 +1804,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1700,7 +1832,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -1714,7 +1846,7 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, } void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { - NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; // Send instruction produces a tuple of {aliased operand, U32 context}. @@ -1735,7 +1867,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; // Recv instruction produces a tuple of {receive buffer, U32 context}. 
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index f18306fff0..d7e50772c4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include <map> #include <string> +#include <type_traits> #include <utility> #include "tensorflow/compiler/xla/client/padding.h" @@ -46,22 +47,25 @@ class XlaBuilder; // instruction as an operand. class XlaOp { public: - XlaOp() : handle_(-1), builder_(nullptr) {} - ~XlaOp() {} - - XlaBuilder* builder() const { return builder_; } - - bool operator==(const XlaOp& rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible<XlaOp>::value, + "XlaOp should be trivially destructible"); } + ~XlaOp() = default; - bool operator!=(const XlaOp& rhs) const { - return handle_ != rhs.handle_ || builder_ != rhs.builder_; - } + XlaBuilder* builder() const { return builder_; } // Returns true if the XlaOp represents valid, non-erroneous value. bool valid() const { return handle_ >= 0; } + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } + + bool IsIdenticalTo(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { out << op.handle(); return out; @@ -84,6 +88,30 @@ class XlaOp { XlaBuilder* builder_; }; +// Arithmetic operator overloads for the XlaOp type. 
+XlaOp operator-(const XlaOp& x); +XlaOp operator+(const XlaOp& x, const XlaOp& y); +XlaOp operator-(const XlaOp& x, const XlaOp& y); +XlaOp operator*(const XlaOp& x, const XlaOp& y); +XlaOp operator/(const XlaOp& x, const XlaOp& y); +XlaOp operator%(const XlaOp& x, const XlaOp& y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(const XlaOp& x); +XlaOp operator&(const XlaOp& x, const XlaOp& y); +XlaOp operator|(const XlaOp& x, const XlaOp& y); +XlaOp operator^(const XlaOp& x, const XlaOp& y); +XlaOp operator<<(const XlaOp& x, const XlaOp& y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(const XlaOp& x, const XlaOp& y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + // A convenient interface for building up computations. // // Thread-compatible. @@ -822,6 +850,24 @@ class XlaBuilder { // Returns the (inferred) result for the current computation's shape. StatusOr<ProgramShape> GetProgramShape() const; + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr<XlaOp> into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. 
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op); + + // A helper function that runs a function that returns a StatusOr<XlaOp> and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); + private: StatusOr<XlaOp> AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, @@ -830,14 +876,6 @@ class XlaBuilder { void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - XlaOp NoteErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); - StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const; // Internal helper method that does the building for an arbitrary unary op. diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index 0680b38f3a..8a5bf96714 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -59,6 +59,76 @@ TEST_F(XlaBuilderTest, OnePlusTwo) { EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { + auto test_unary_operator = + [&](std::function<XlaOp(XlaOp)> op, + ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { + XlaBuilder b(TestName()); + op(b.ConstantR0<int32>(1)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); + test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); +} + 
+TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { + auto test_binary_operator = + [&](std::function<XlaOp(XlaOp, XlaOp)> op, + ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { + XlaBuilder b(TestName()); + op(b.ConstantR0<int32>(1), b.ConstantR0<int32>(2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + + test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, + op::Add(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, + op::Subtract(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, + op::Multiply(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, + op::Divide(op::Constant(), op::Constant())); + + test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, + op::And(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, + op::Or(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, + op::Xor(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, + op::ShiftLeft(op::Constant(), op::Constant())); + test_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightArithmetic(op::Constant(), op::Constant())); + + auto test_unsigned_binary_operator = + [&](std::function<XlaOp(XlaOp, XlaOp)> op, + ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { + XlaBuilder b(TestName()); + op(b.ConstantR0<uint32>(1), b.ConstantR0<uint32>(2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unsigned_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightLogical(op::Constant(), op::Constant())); 
+} + +TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { + XlaBuilder b(TestName()); + b.ConstantR0<float>(1) >> b.ConstantR0<float>(2); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Argument to >> operator does not have an integral type")); +} + TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); @@ -221,5 +291,32 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, ReportError) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Add(b.ReportError(InvalidArgument("a test error")), x); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { + XlaBuilder b(TestName()); + StatusOr<XlaOp> op(b.ConstantR0<float>(1.0)); + b.Add(b.ReportErrorOrReturn(op), b.ConstantR0<float>(2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { + XlaBuilder b(TestName()); + StatusOr<XlaOp> op(InvalidArgument("a test error")); + b.Add(b.ReportErrorOrReturn(op), b.ConstantR0<float>(2.0)); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 3f059cac30..15eeb2ea13 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -248,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 
rank) { } } + if (layout.format() == SPARSE) { + if (!layout.padded_dimensions().empty()) { + return InvalidArgument("Sparse layout has padded dimensions"); + } + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h new file mode 100644 index 0000000000..8657d3a4bf --- /dev/null +++ b/tensorflow/compiler/xla/overflow_util.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Multiply two nonnegative int64's, returning negative for overflow +inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) { + // Multiply in uint64 rather than int64 since signed overflow is undefined. + // Negative values will wrap around to large unsigned values in the casts + // (see section 4.7 [conv.integral] of the C++14 standard). + const uint64 ux = x; + const uint64 uy = y; + const uint64 uxy = ux * uy; + + // Check if we overflow uint64, using a cheap check if both inputs are small + if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) { + // Ensure nonnegativity. 
Note that negative numbers will appear "large" + // to the unsigned comparisons above. + CHECK(x >= 0 && y >= 0); + + // Otherwise, detect overflow using a division + if (ux != 0 && uxy / ux != uy) return -1; + } + + // Cast back to signed. Any negative value will signal an error. + return static_cast<int64>(uxy); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 29062348b0..734d9334fd 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -511,22 +511,14 @@ LocalOp LocalComputationBuilder::Rev( LocalOp LocalComputationBuilder::Map( tensorflow::gtl::ArraySlice<LocalOp> operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<LocalOp> static_operands) { + tensorflow::gtl::ArraySlice<int64> dimensions) { std::vector<XlaOp> xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { xla_ops.push_back(op.op()); } - std::vector<XlaOp> static_xla_ops; - static_xla_ops.reserve(static_operands.size()); - for (const auto& op : static_operands) { - static_xla_ops.push_back(op.op()); - } - - return builder_.Map(xla_ops, local_computation.computation(), dimensions, - static_xla_ops); + return builder_.Map(xla_ops, local_computation.computation(), dimensions); } LocalOp LocalComputationBuilder::Reduce( @@ -621,6 +613,7 @@ _FORWARD_BINOP(Max) _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) +_FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 95f0a0610b..e920f8aecd 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ 
b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -270,8 +270,7 @@ class LocalComputationBuilder { LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<LocalOp> static_operands); + tensorflow::gtl::ArraySlice<int64> dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, @@ -333,6 +332,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) + _FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 477df6fde2..76e9e637cd 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -988,6 +988,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Min; %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Xor; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a1fc25303c..abb97d0c6f 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -123,6 +123,7 @@ _BINARY_OPS = [ 'Min', 'And', 'Or', + 'Xor', 'Pow', ] @@ -908,20 +909,19 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.c_local_computation, operands) - def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation 
onto the computation. Args: operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. - static_operands: auxiliary arguments passed to the applied computation. Returns: A LocalOp representing the added Map op. """ return self._client.Map(operands, computation_to_apply.c_local_computation, - dimensions, static_operands) + dimensions) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 71e1d60a4e..0564ddcb85 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayBool([True, True, False, False]))) self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + def testBooleanXor(self): + c = self._NewComputation() + c.Xor( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + def testSum2DF32(self): c = self._NewComputation() c.Add( @@ -1168,14 +1175,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) - def DISABLED_testMapWithStaticOperands(self): - c = self._NewComputation() - factor = c.ConstantF32Scalar(3.0) - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF32ByParamComputation(), [0], - static_operands=[factor]) - self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) - def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), diff --git 
a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index d7dd9786a2..4031320001 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -91,7 +91,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { auto y = builder.ConstantR1<float>( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector<float> expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index c08960a57b..0833289b73 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2094,6 +2094,7 @@ cc_library( hdrs = ["hlo_verifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index d8a9aba834..4858fe61e0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -50,20 +50,15 @@ namespace { namespace m = match; -// Returns whether operand is a literal with the given value. 
-bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAll(value); -} - bool IsAll(const HloInstruction* op, int8 value) { - if (IsLiteralWithValue(op, value)) { - return true; - } - if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { - return true; + switch (op->opcode()) { + case HloOpcode::kBroadcast: + return IsAll(op->operand(0), value); + case HloOpcode::kConstant: + return op->literal().IsAll(value); + default: + return false; } - return false; } // Returns whether the given transpose produces a result which is bit-wise @@ -160,9 +155,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMap(HloInstruction* map) override; - Status HandleMaximum(HloInstruction* maximum) override; - Status HandleMinimum(HloInstruction* minimum) override; - // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -201,8 +193,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Helper method to perform and add reduction in a single dimension. 
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { - HloInstruction* zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction* zero = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -572,6 +565,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return Status::OK(); } +namespace { +template <typename T> +Status InvertConstant(const HloInstruction& constant, Literal* result) { + return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) { + return T{1.0} / constant.literal().Get<T>(indices); + }); +} +} // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { Shape* shape; @@ -633,14 +634,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) 
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - HloInstruction* one = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(a->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = computation_->AddInstruction( - HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kMultiply, a, inverse)); + Literal new_literal(b->shape()); + switch (b->shape().element_type()) { + case F16: + TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal)); + break; + case F32: + TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal)); + break; + case BF16: + TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal)); + break; + case F64: + TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal)); + break; + case C64: + TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal)); + break; + default: + return Status::OK(); + } + auto inverse = computation_->AddInstruction( + HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); + return ReplaceInstruction(divide, new_divide); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) @@ -660,18 +678,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a, b_times_c)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); + return ReplaceInstruction(divide, new_divide); } // A / (B / C) => (A*C) / B if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { TF_ASSIGN_OR_RETURN(auto a_times_c, 
MakeBinaryHlo(HloOpcode::kMultiply, a, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a_times_c, b)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); + return ReplaceInstruction(divide, new_divide); } return Status::OK(); @@ -2074,10 +2092,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(convolution->shape().element_type(), {}), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(convolution->shape().element_type()) + .CloneToUnique())), {})); } const auto& window = convolution->window(); @@ -2249,68 +2266,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { return ReplaceWithNewInstruction(map, std::move(clone)); } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { - // Match the following tree: - // min_operand operand - // \ / - // max_operand min - // \ / - // max - // where max_operand and min_operand are scalar constants. 
- { - HloInstruction* min; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMinimum, maximum, - /*matching_operand=*/&min, - /*other_operand=*/&max_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, min, - /*matching_operand=*/&min_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { - // Match the following tree: - // max_operand operand - // \ / - // min_operand max - // \ / - // min - // where max_operand and min_operand are scalar constants. - { - HloInstruction* max; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMaximum, minimum, - /*matching_operand=*/&max, - /*other_operand=*/&min_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, max, - /*matching_operand=*/&max_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 49cc0b808b..b733f6f59e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -201,8 +201,11 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = 
builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); - builder.AddInstruction( - HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + builder.AddInstruction(HloInstruction::CreateMap( + r2f32, + {param0, builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {}))}, + add_computation)); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -211,7 +214,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, zero)); + EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -367,17 +370,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r2f32, "param1")); HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r2f32, "param2")); HloInstruction* param3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction::CreateParameter(3, r2f32, "param3")); HloInstruction* div0 = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); HloInstruction* div1 = builder.AddInstruction( @@ -398,8 +400,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), op::Divide(op::Multiply(param0, 
param3), op::Multiply(param1, param2))); - EXPECT_TRUE( - ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -459,7 +459,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -467,7 +466,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* param2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); builder.AddInstruction( @@ -484,14 +483,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); - - const HloInstruction* negate = - computation->root_instruction()->operand(1)->operand(1); - const Shape& negate_shape = negate->shape(); - EXPECT_EQ(0, negate_shape.dimensions_size()); } -// A / Const => A * (1 / Const) +// A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); @@ -510,20 +504,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Divide(op::Constant(), constant))); + op::Multiply(param0, op::Constant())); } // pow(pow(A, X), Y) => pow(A, X*Y) 
TEST_F(AlgebraicSimplifierTest, PowerOfPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, @@ -540,15 +533,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { - Shape r0c64 = ShapeUtil::MakeShape(C64, {}); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction::CreateParameter(1, r1c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction::CreateParameter(2, r1c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, @@ -1416,33 +1408,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -// Regression test for a bug in the reshape sinking 
transformation, where -// moving a reshape to a scalar led to a crash. -TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); - HloInstruction* reshape = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.}))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); -} - // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. 
TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { @@ -2103,160 +2068,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { EXPECT_EQ("NO_CHANGE", build_and_simplify()); } -// Test that max(min(A, x), y) is transformed to clamp(y, A, x) -TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); - HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, param0, min_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Minimum(param0, min_value), max_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar -// values. 
-TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for -// broadcasted scalar values. 
-TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to -// clamp(non-constant1, A, non-constant2) -TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - 
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); -} - -// Test that min(f(max(A, constant1)), constant2) is not transformed to -// clamp(constant1, A, constant2) -TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* fmax = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); - builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, fmax, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, 
max_value), max_value), - min_value)); -} - // Test that slice(broadcast(/*scalar value*/)) simplifies to a single // broadcast. TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index efa4696130..28b5a5784f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1874,11 +1874,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); - auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = + builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, "")); + auto infeed_data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r0s32, infeed, 0)); auto cond0 = module->AddEmbeddedComputation(build_cond()); auto body0 = module->AddEmbeddedComputation(build_body()); auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data)); auto cond1 = module->AddEmbeddedComputation(build_cond()); auto body1 = module->AddEmbeddedComputation(build_body()); @@ -1909,8 +1913,8 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. 
SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = {infeed, while0, while1, zero, - add, while2, tuple}; + sequence[module->entry_computation()] = { + token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; TF_ASSERT_OK_AND_ASSIGN( auto assignment, BufferAssigner::Run( diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 738d00881d..924348c870 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -148,14 +148,16 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({})); outfeeder.AddInstruction( - HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); + HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); HloComputation::Builder outer(TestName() + ".outer"); outer.AddInstruction(HloInstruction::CreateCall( - ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); + outfeed_computation->root_instruction()->shape(), /*operands=*/{}, + outfeed_computation)); module->AddEntryComputation(outer.Build()); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 868348547d..c38719d50e 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -144,8 +144,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), 
HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - false_computation->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = + false_computation->AddInstruction(HloInstruction::CreateAfterAll({})); + false_computation->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index ed1a50f516..e7539759ce 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1605,8 +1605,8 @@ HloModule TokensShouldNotBeCopied %constant.1 = s32[] constant(1) %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %generate-token = token[] generate-token(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) } %Cond (param: (s32[], token[])) -> pred[] { @@ -1619,7 +1619,7 @@ HloModule TokensShouldNotBeCopied ENTRY %TokensShouldNotBeCopied () -> s32[] { %one = s32[] constant(1) %negative_one = s32[] negate(%one) - %init_token = token[] generate-token() + %init_token = token[] after-all() %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b703be0f39..2c3eb1ae36 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ 
b/tensorflow/compiler/xla/service/cpu/BUILD @@ -54,29 +54,6 @@ cc_library( ) cc_library( - name = "external_constant_pool", - srcs = ["external_constant_pool.cc"], - hdrs = ["external_constant_pool.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "external_constant_pool_test", - srcs = ["external_constant_pool_test.cc"], - deps = [ - ":external_constant_pool", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - -cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], hdrs = ["cpu_compiler.h"], @@ -175,7 +152,6 @@ cc_library( ":cpu_runtime", ":custom_call_target_registry", ":disassembler", - ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", @@ -256,7 +232,6 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":external_constant_pool", ":ir_emission_utils", ":ir_function", ":parallel_loop_emitter", @@ -273,6 +248,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 52da9d6eac..55962ba70d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -269,6 +269,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, 
/*enable_dot_strength_reduction=*/false); + pass.AddPass<HloDCE>(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -306,11 +307,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, module->mutable_entry_computation_layout(), &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); - pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); + { + auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>( + "after layout assignement"); + pass.AddPass<HloPassFix<AlgebraicSimplifier>>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_strength_reduction=*/false); + pass.AddPass<HloDCE>(); + pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true); + } pipeline.AddPass<HloElementTypeConverter>(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. 
const int max_parallelism = @@ -578,7 +584,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, jit->external_constant_pool()); + &target_machine_features); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -765,8 +771,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, - /*external_constant_pool=*/nullptr); + &target_machine_features); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 97e10a89a2..750310c633 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { HloInstruction* exp = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get()))); module->AddEntryComputation(builder.Build()); @@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + 
HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get()))); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc deleted file mode 100644 index c562865591..0000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" - -#include <algorithm> -#include <cstdlib> -#include <cstring> - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" - -namespace xla { -namespace cpu { -void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, - int64 alignment) { - CHECK(!ShapeUtil::IsTuple(literal.shape())); - CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment))); - CHECK(entries_.find(name) == entries_.end()); - - const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer = tensorflow::port::AlignedMalloc( - literal_size, std::max<size_t>(alignment, sizeof(void*))); - CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size - << " bytes with alignment of " << alignment; - - std::memcpy(raw_pointer, literal.untyped_data(), literal_size); - entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer)); -} - -const uint8* ExternalConstantPool::Find(const string& name) { - auto it = entries_.find(name); - return it == entries_.end() ? nullptr : it->second.get(); -} -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h deleted file mode 100644 index 0677f5f0b5..0000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ - -#include <memory> - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/platform/mem.h" - -namespace xla { -namespace cpu { -// An ExternalConstantPool maintains a set of constants kept external to -// generated LLVM IR. These constants are accessed from the IR via globals with -// extern linkage. This current incarnation of ExternalConstantPool only -// supports the JIT CPU backend; the AOT backend is not supported. -// -// Implementation-wise, this is a simple wrapper around a map of strings to byte -// buffers. This simply implementation works in a JIT scenario. This class -// will have to become smarter if we decide to support external constant pools -// on AOT compiles in the future. -class ExternalConstantPool { - public: - // Inserts a buffer with the contents of `literal` into the constant pool with - // the name `name`. It is an error to try to insert two constants with the - // same `name` into the same constant pool. The buffer for literal is aligned - // to `aligment` bytes, and `alignment` must be a power of 2. - // - // The constant pool copies out the contents of `literal` into a buffer it - // owns -- it does not keep pointers to `literal`, or to memory owned by - // `literal`. 
- void Insert(string name, const LiteralSlice& literal, int64 alignment); - - // Find the constant with name `name` in this constant pool. If there isn't - // such constant, return nullptr. - const uint8* Find(const string& name); - - private: - // We need to `AlignedFree` pointers allocated into `entries_` since we - // allocate them with `AlignedMalloc`. - struct FreeDeleter { - void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } - }; - - tensorflow::gtl::FlatMap<string, std::unique_ptr<uint8, FreeDeleter>> - entries_; -}; -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc deleted file mode 100644 index 9290a4e5df..0000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace cpu { -namespace { -class ExternalConstantPoolTest : public ::testing::Test {}; - -template <typename T> -T GetFromBuffer(const uint8* buffer, int64 index) { - T result; - std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); - return result; -} - -TEST(ExternalConstantPoolTest, Basic) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1); - EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 2); - EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 3); - EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4); - - EXPECT_EQ(constant_pool.Find("name-1"), nullptr); -} - -TEST(ExternalConstantPoolTest, RowMinorLayout) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2WithLayout( - {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1); - EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 3); - EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 2); - EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4); -} - -TEST(ExternalConstantPoolTest, Alignment) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - - for (int i = 0; i < 8; i++) { - int64 alignment = 1 << i; - string name = tensorflow::strings::StrCat("name-", i); - - const auto literal = 
Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert(name, *literal, alignment); - - const uint8* constant = constant_pool.Find(name); - ASSERT_NE(constant, nullptr); - EXPECT_EQ(reinterpret_cast<intptr_t>(constant) % alignment, 0); - } -} - -} // namespace -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 75e8e9a835..6b66a4b0b7 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -48,6 +48,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -83,8 +85,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx, std::unordered_map<const HloComputation*, int64> computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features, - ExternalConstantPool* external_constant_pool) + const TargetMachineFeatures* target_machine_features) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -94,8 +95,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features), - external_constant_pool_(external_constant_pool) { + target_machine_features_(*target_machine_features) { 
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -161,45 +161,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { } llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::Constant* result; - - // We avoid creating large constants in the LLVM IR since LLVM is not - // efficient for large constant arrays. We still emit "small enough" constant - // arrays into the Ir, in the off chance the LLVM optimizer can do something - // interesting with it. - // - // TODO(b/29904935): Remove the large constant pool. - const int kMaxInternalConstantSizeInBytes = 128; - if (external_constant_pool_ && - ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { - string global_name = tensorflow::strings::StrCat( - "constant_global_", external_global_constant_counter_++); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/IrShapeType(literal.shape()), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/nullptr, - /*Name=*/AsStringRef(global_name)); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - external_constant_pool_->Insert(global_name, literal, - MinimumAlignmentForShape(literal.shape())); - result = result_global; - } else { - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - result = llvm::ConstantExpr::getBitCast( - result_global, IrShapeType(literal.shape())->getPointerTo()); - } - return result; + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, 
module_); + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + return llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } Status IrEmitter::HandleConstant(HloInstruction* constant) { @@ -321,30 +294,42 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { return DefaultAction(select); } -Status IrEmitter::HandleInfeed(HloInstruction* infeed) { +Status IrEmitter::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); VLOG(2) << "HandleInfeed: " << infeed->ToString(); - const Shape& shape = infeed->shape(); - - // The infeed operation produces data (dequeued from the infeed queue) at this - // address, which has been provided by buffer assignment. + // The infeed operation produces a two-element tuple containing data and a + // token value. HloInfeedInstruction::infeed_shape gives us the data shape. + const Shape& data_shape = infeed->infeed_shape(); + DCHECK(ShapeUtil::Equal(data_shape, + ShapeUtil::GetTupleElementShape(infeed->shape(), 0))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); - llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); - if (ShapeUtil::IsTuple(shape)) { - TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + // Write the tuple index table. 
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(infeed, {0})); + llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, + assignment_.GetUniqueSlice(infeed, {1})); + llvm::Value* token_address = EmitTempBufferPointer( + token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, + &ir_builder_, module_); + + if (ShapeUtil::IsTuple(data_shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to // their corresponding target locations. We then construct the // tuple outer buffer containing pointers to the internal // elements. std::vector<llvm::Value*> tuple_element_addresses; - for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - assignment_.GetUniqueSlice(infeed, {i})); + assignment_.GetUniqueSlice(infeed, {0, i})); const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); + ShapeUtil::GetTupleElementShape(data_shape, i); // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root @@ -359,11 +344,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, - module_); + llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), + tuple_element_addresses, &ir_builder_, module_); } else { - TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, - GetEmittedValueFor(infeed))); + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); } return Status::OK(); @@ -2539,7 +2524,7 @@ Status 
IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } -Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) { +Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index e1815c1db7..3c110a320f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -30,7 +30,6 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -67,17 +66,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. - // external_constant_pool: if non-null, points to an ExternalConstantPool - // instance into which the Ir emitter can spill - // constants. 
IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx, std::unordered_map<const HloComputation*, int64> computation_to_profile_idx, - const TargetMachineFeatures* target_machine, - ExternalConstantPool* external_constant_pool); + const TargetMachineFeatures* target_machine); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -150,7 +145,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; - Status HandleGenerateToken(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -537,9 +532,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const TargetMachineFeatures& target_machine_features_; - int64 external_global_constant_counter_ = 0; - ExternalConstantPool* external_constant_pool_; - struct LiteralPtrHashFunctor { size_t operator()(const Literal* literal) const { return literal->Hash(); } }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fc2efbaf9a..36c9f74385 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = u32[12345678,2]{1,0} infeed() - ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + infeed0.data = 
u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 + ROOT outfeed0 = token[] outfeed(infeed0.data) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 167aa4adda..e3965b4e05 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -51,7 +51,7 @@ int main(int argc, char** argv) { xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p1, p0, {0}); + builder.Add(p1, p0, {0}); xla::StatusOr<xla::XlaComputation> computation_status = builder.Build(); xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c4c90515ac..be772cfb7e 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -127,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast<uint64_t>(from_constant_pool), - llvm::JITSymbolFlags::None); - } - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 1851a3ee0b..d74b63fcf4 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -91,10 +90,6 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } - ExternalConstantPool* external_constant_pool() { - return &external_constant_pool_; - } - // Creates an llvm::TargetMachine suitable for JITting code that will run on // the current machine. static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT( @@ -112,7 +107,6 @@ class SimpleOrcJIT { std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; - ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 3a7255c1d2..1d4bf483ae 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest { TEST_F(CpuExternalConstantsTest, Basic) { TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16 )"); } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 23e7a3de4d..783b2820e9 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -96,8 +96,11 @@ 
TEST_F(CpuFusionTest, FuseElementwiseOpChain) { HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + vshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction->fused_expression_root()->opcode()); - // There should be 7 fused instructions: 2 parameters and the fused + // There should be 8 fused instructions: 2 parameters and the fused // operations. - EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. 
auto result = ExecuteAndTransfer(std::move(module), {}); @@ -170,8 +173,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + cshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction1->fused_expression_root()->opcode()); - // There should be 5 fused instructions in the root fusion instruction: 2 + // There should be 6 fused instructions in the root fusion instruction: 2 // parameters, multiply, floor, and exp. 
- EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + EXPECT_EQ(6, fusion_instruction1->fused_instruction_count()) << fusion_instruction1->fused_instructions_computation()->ToString(); auto fusion_instruction2 = reduce->operand(0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 1739b6e8b7..90b99c828e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -49,8 +50,8 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = () outfeed(f32[2,3,2] const_a) - ROOT out1 = () outfeed(f32[2,3,2] const_b) + out0 = token[] outfeed(f32[2,3,2] const_a) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b) } )"; @@ -84,7 +85,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 40b4d0ed00..dac416e1c7 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,7 +32,8 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - ROOT out = () outfeed(f32[2,3,2] const_a) + outfeed = token[] outfeed(f32[2,3,2] const_a) + ROOT root = () tuple() } )"; diff --git 
a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 7d56d57b5f..cb3676c5ba 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -246,7 +246,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; - virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6934e00a4b..987c91e5ba 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,7 +188,7 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } - Status HandleGenerateToken(HloInstructionPtr token) override { + Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index af6d298589..2508755e4c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -442,6 +442,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index db6924c742..c77e3c81c9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -126,12 +126,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { HloInstruction* variance_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + inverse_stddev->shape(), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))), + {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( variance_plus_epsilon->shape(), HloOpcode::kSubtract, - variance_plus_epsilon, epsilon)); + variance_plus_epsilon, + computation_->AddInstruction(HloInstruction::CreateBroadcast( + variance_plus_epsilon->shape(), epsilon, {})))); // Repackage the results. std::unique_ptr<HloInstruction> new_tuple = HloInstruction::CreateTuple({ @@ -175,12 +180,17 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { HloInstruction* var_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( batch_norm->operand(3)->shape(), HloOpcode::kAdd, - batch_norm->mutable_operand(3), epsilon)); + batch_norm->mutable_operand(3), + computation_->AddInstruction(HloInstruction::CreateBroadcast( + batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = computation_->AddInstruction(HloInstruction::CreateBinary( var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(-.5))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + var_plus_epsilon->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0<float>(-.5))), + {})))); std::vector<HloInstruction*> operands(batch_norm->operands().begin(), 
batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index ea34d5b30c..2b63d8727c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -22,29 +22,29 @@ namespace xla { namespace gpu { InfeedThunk::InfeedThunk( - tensorflow::gtl::ArraySlice<BufferAllocation::Slice> tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + const ShapeTree<BufferAllocation::Slice>& infeed_slices, const HloInstruction* hlo_instruction) - : Thunk(Kind::kInfeed, hlo_instruction), - tuple_element_buffers_(tuple_element_buffers.begin(), - tuple_element_buffers.end()), - destination_buffer_(destination_buffer) {} + : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - se::DeviceMemoryBase destination_address = - buffer_allocations.GetDeviceAddress(destination_buffer_); - + // First copy the infeed data which is element 0 of the infeed instruction's + // two-tuple output (the other element is a token). + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); std::vector<InfeedBuffer*> infeed_buffers; - if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { - CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0); + if (ShapeUtil::IsTuple(data_shape)) { + CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // Transfer the tuple elements first. 
std::vector<void*> tuple_element_addresses; - for (BufferAllocation::Slice tuple_element_buffer : - tuple_element_buffers_) { + for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element({0, i}); se::DeviceMemoryBase tuple_element_address = buffer_allocations.GetDeviceAddress(tuple_element_buffer); @@ -56,15 +56,23 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, } // Transfer the tuple outer buffer. auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + stream->ThenMemcpy(&data_address, tuple_element_addresses.data(), host_size); } else { InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + stream->ThenMemcpy(&data_address, *(buffer->device_memory()), buffer->length()); } + // Construct top-level tuple of infeed containing the data and the token. Use + // a nullptr for the token, it should never be dereferenced. 
+ std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase top_level_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); + stream->ThenMemcpy(&top_level_address, infeed_addresses.data(), + 2 * sizeof(void*)); + Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 93713cb12d..cb9a6232f3 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -32,12 +32,8 @@ namespace gpu { class InfeedThunk : public Thunk { public: // Constructs a InfeedThunk that copies data from the on-device - // infeed queue to the device buffer - // `destination_buffer`. `mem_size` is the size of the data in - // bytes. - InfeedThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice> - tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + // infeed queue into the buffers in the given shape tree. 
+ InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices, const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; @@ -47,8 +43,7 @@ class InfeedThunk : public Thunk { se::Stream* stream) override; private: - const std::vector<BufferAllocation::Slice> tuple_element_buffers_; - const BufferAllocation::Slice destination_buffer_; + const ShapeTree<BufferAllocation::Slice> infeed_slices_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index efeb276470..d5e07c3afb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -191,6 +191,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( HloOpcode root_opcode = computation.root_instruction()->opcode(); PrimitiveType element_type = computation.root_instruction()->shape().element_type(); + bool is_atomic_integral = element_type == S32 || element_type == U32 || + element_type == S64 || element_type == U64; llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. @@ -201,7 +203,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( {output_address->getType()}, &ir_builder_); return true; } - if (primitive_util::IsIntegralType(element_type)) { + if (is_atomic_integral) { // integral + integral ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, @@ -210,9 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } } - // NVPTX supports atomicMax and atomicMin on only integer types. - if (root_opcode == HloOpcode::kMaximum && - primitive_util::IsIntegralType(element_type)) { + // NVPTX supports atomicMax and atomicMin only on integer types. + if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) { // max(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? 
llvm::AtomicRMWInst::Max @@ -222,8 +223,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( return true; } - if (root_opcode == HloOpcode::kMinimum && - primitive_util::IsIntegralType(element_type)) { + if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) { // min(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f6f0a45124..fbd647f251 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -615,6 +615,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { output_shape_index = {i}; } if (inst->opcode() == HloOpcode::kReduce) { + CHECK(IsReductionToVector(*inst)) + << "Only reductions to vector are supported"; // Shapes, layouts and dimensions must be the same for all reduces // inside of this fusion. CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); @@ -1970,10 +1972,8 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that - // ingitializes the output array to the initial value of the reduce. - if (IsReductionToVector(*reduce) && - // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits - 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + // initializes the output array to the initial value of the reduce. 
+ if (IsReductionToVector(*reduce)) { TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk, BuildInitializerThunk(reduce)); std::vector<std::unique_ptr<Thunk>> thunks; @@ -2311,7 +2311,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) { +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } @@ -2563,17 +2563,14 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - std::vector<BufferAllocation::Slice> tuple_element_buffers; - for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { - BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(inst, {i}) - .ConsumeValueOrDie(); - tuple_element_buffers.push_back(buffer); - } - - return MakeUnique<InfeedThunk>( - tuple_element_buffers, - /*destination_buffer=*/GetAllocationSlice(*inst), inst); + ShapeTree<BufferAllocation::Slice> slices(inst->shape()); + slices.ForEachMutableElement( + [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + *slice = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, index) + .ConsumeValueOrDie(); + }); + return MakeUnique<InfeedThunk>(slices, inst); } namespace { @@ -2718,7 +2715,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( uint8 b = literal_bytes.front(); pattern16 = uint16{b} | (uint16{b} << 8); } else { - pattern16 = literal_bytes.front(); + memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {MakeUnique<Memset32BitValueThunk>( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 279a5c386a..819060061a 100644 --- 
a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,7 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleGenerateToken(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const HloInstruction& hlo, diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index d541776f00..652b5c7687 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,9 +23,11 @@ limitations under the License. #include <string> #include <utility> +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -69,6 +71,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // In that case, the operand of the reduce needs to have the same shape // as the other tuple operands, but also we need to compare the output // shapes of the reduces. + // TODO(tjoerg): Allow differences in fp precision. 
auto* element_instr_1 = get_element_instr(instr1); auto* element_instr_2 = get_element_instr(instr2); if (element_instr_1->opcode() == HloOpcode::kReduce && @@ -82,26 +85,33 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, } namespace { -bool IsReduction(HloInstruction* instr) { +bool IsInputFusibleReduction(HloInstruction* instr) { if (instr->IsMultiOutputFusion()) { for (const HloInstruction* operand : instr->fused_expression_root()->operands()) { if (operand->opcode() == HloOpcode::kReduce) { + CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) + << " Reduce multi-output fusion " << instr->ToString() + << " must be an input fusion."; return true; } } return false; } else if (instr->opcode() == HloOpcode::kFusion) { - return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + // The loop emitter can handle to-vector reduce fusions. Such reduce + // fusions have the fusion kind kLoop rather than kInput. We do not fuse + // to-vector reduce fusions, because the resulting fusions may no longer be + // supported by loop emitter. + return IsReductionToVector(*instr->fused_expression_root()); } else { - return instr->opcode() == HloOpcode::kReduce; + return IsReductionToVector(*instr); } } } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. 
- return IsReduction(instr) || + return IsInputFusibleReduction(instr) || (instr->opcode() == HloOpcode::kFusion && instr->fusion_kind() == HloInstruction::FusionKind::kLoop && // TODO(b/110202584): bitcasts make nested fusions, GPU has no support @@ -147,5 +157,110 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; } +bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { + bool changed = false; + RecomputeReachability(); + + tensorflow::gtl::FlatSet<HloInstruction*> to_fuse; + // Keep a list of the instructions to fuse after making all the fusion + // decisions. We first aggressively add instructions to potential_fusion_list, + // then filter out instructions that will be no longer fusable because of + // reachability change. This avoids recalculating reachability on a large set + // of instructions. + std::vector<std::pair<HloInstruction*, HloInstruction*>> + potential_fusion_list; + std::vector<std::pair<HloInstruction*, HloInstruction*>> fusion_list; + std::vector<HloInstruction*> instrs_to_update_reachability; + + // For each reduce or reduce multi-output fusion, try to fuse it with loop + // fusions operands. + for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { + if (consumer->user_count() == 0) { + continue; + } + if (!IsInputFusibleReduction(consumer)) { + continue; + } + + auto consumer_operands = consumer->operands(); + for (size_t i = 0; i < consumer_operands.size(); ++i) { + HloInstruction* producer = consumer_operands[i]; + if (!producer->IsFusable()) { + continue; + } + const bool is_loop_fusion = + producer->opcode() == HloOpcode::kFusion && + producer->fusion_kind() == HloInstruction::FusionKind::kLoop; + if (!is_loop_fusion) { + continue; + } + if (!ShapesCompatibleForFusion(producer, consumer)) { + continue; + } + // If we have already decided to fuse this producer, skip it. 
+ if (ContainsKey(to_fuse, producer)) { + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + break; + } + to_fuse.insert(producer); + potential_fusion_list.emplace_back(producer, consumer); + instrs_to_update_reachability.push_back(producer); + instrs_to_update_reachability.push_back(consumer); + break; + } + } + + // Filter out pairs that will be no longer fusable because of reachability + // change. + for (auto& fusion_pair : potential_fusion_list) { + HloInstruction* producer = fusion_pair.first; + HloInstruction* consumer = fusion_pair.second; + if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + UpdateReachability(producer, consumer, instrs_to_update_reachability); + fusion_list.push_back(fusion_pair); + } + } + + for (auto fusions_to_create : fusion_list) { + HloInstruction* producer = fusions_to_create.first; + HloInstruction* consumer = fusions_to_create.second; + if (consumer->opcode() != HloOpcode::kFusion) { + // Fusing with a reduce (fusion) always results in an input fusion. 
+ HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), HloInstruction::FusionKind::kInput, consumer)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer->name() << " into " << input_fusion->name(); + TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + } + } else { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer->name(); + + if (producer->opcode() == HloOpcode::kFusion) { + consumer->MergeFusionInstructionIntoMultiOutput(producer); + } else { + consumer->FuseInstructionIntoMultiOutput(producer); + } + } + changed = true; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 16db0e0f02..67ca5d49ee 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -45,6 +45,9 @@ class GpuMultiOutputFusion : public MultiOutputFusion { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; + + // Fuse loop fusions into reduce fusions. 
+ bool DoProducerConsumerMultiOutputFusion() override; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 5e7ceb7976..979ea79243 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -255,5 +255,99 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Add())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = 
pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + 
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c057be8201..34b18b0e21 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } +namespace { + +// Returns the new name for a fusion parameter when we change its number. +// +// Fusion parameters are named foo.param_1, bar.param_2, etc. We are +// renumbering the parameters, so replace the final number in the name with +// the updated value. 
+string RenameFusionParameter(const string& original_name, int64 new_param_no) { + const string param_underscore = ".param_"; + size_t index = original_name.rfind(param_underscore); + if (index == string::npos) { + return original_name; + } + string after_param = original_name.substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + return StrCat(original_name.substr(0, index + param_underscore.size()), + new_param_no); + } + return original_name; +} + +} // namespace + Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->name(); - // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters, so replace the final number in the name with - // the updated value. 
- const string param_underscore = ".param_"; - size_t index = param_name.rfind(param_underscore); - if (index == string::npos) { - string after_param = name().substr(index + param_underscore.size()); - int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { - param_name = - StrCat(param_name.substr(0, index), param_underscore, param_no); - } - } - + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); @@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } +Status HloComputation::RemoveUnusedParameters() { + CHECK(IsFusionComputation()); + int64 removed = 0; + for (int64 i = 0; i < param_instructions_.size(); ++i) { + HloInstruction* param_instruction = param_instructions_[i]; + if (param_instruction->user_count() == 0 && + param_instruction != root_instruction()) { + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + ++removed; + continue; + } + + if (removed > 0) { + const int64 param_no = i - removed; + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + } + } + param_instructions_.resize(param_instructions_.size() - removed); + return Status::OK(); +} + bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for diff --git a/tensorflow/compiler/xla/service/hlo_computation.h 
b/tensorflow/compiler/xla/service/hlo_computation.h index 0f111a1a76..c1c3e79ebc 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -113,6 +113,11 @@ class HloComputation { // instruction. Status RemoveParameter(int64 param_no); + // Remove unused parameters from the computation. + // Note this is only applicatable to the computation for the fusion + // instruction. + Status RemoveUnusedParameters(); + // Add new parameter instruction to the computation. // This should be a new parameter. Instruction will be appended to parameters // and inserted to the instruction list. diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index c504fc51d2..a8f3f0e9c2 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -375,20 +375,20 @@ TEST_F(HloComputationTest, DeepCopyToken) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::GenerateToken()); + EXPECT_THAT(copy, op::AfterAll()); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. 
auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); auto tuple = diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 762e1afc71..8955e26d5c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -393,7 +393,7 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { +Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0d66736fe1..44e5df587c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,7 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; - Status HandleGenerateToken(const HloInstruction* token) override; + Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index d22bef5673..f77e880a77 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -139,7 +139,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -160,7 +160,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_, {0}); + builder.Map({input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -186,7 +186,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + builder.Conv(input, kernel, {1, 1}, Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -207,8 +207,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = - builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1}); + builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -225,8 +224,8 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_, - {4, 5}, {4, 5}, Padding::kValid); + builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_, {4, 5}, + {4, 5}, Padding::kValid); // Run HLO cost analysis. 
auto hlo_module = BuildHloGraph(&builder); @@ -244,9 +243,8 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); - auto result = - builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, - source, builder.ConstantR0<float>(0), add_); + builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, + source, builder.ConstantR0<float>(0), add_); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -278,8 +276,8 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) - auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); + builder.Map({builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, + {0, 1}); // Run HLO cost analysis. 
auto hlo_module = BuildHloGraph(&builder); @@ -421,7 +419,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) { XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); - auto tuple = builder.Tuple({x, y}); + builder.Tuple({x, y}); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK( @@ -446,10 +444,10 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { /*x_dim=*/3}), "kernel"); - auto result = builder.ConvGeneralDilated( - input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, - /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - XlaBuilder::CreateDefaultConvDimensionNumbers(2)); + builder.ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1}, + /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -464,7 +462,7 @@ TEST_F(HloCostAnalysisTest, Slice) { // Test the analysis on a slice. XlaBuilder builder("slice"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.Slice(x, {0}, {1}, {1}); + builder.Slice(x, {0}, {1}, {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -479,7 +477,7 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicSlice(x, builder.ConstantR1<int32>({1}), {1}); + builder.DynamicSlice(x, builder.ConstantR1<int32>({1}), {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -494,8 +492,8 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { // Test the analysis on a slice. 
XlaBuilder builder("dynamic-update-slice"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1<float>({1.0}), - builder.ConstantR1<int32>({1})); + builder.DynamicUpdateSlice(x, builder.ConstantR1<float>({1.0}), + builder.ConstantR1<int32>({1})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 5a56607a66..2822ecd788 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -234,9 +234,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { { auto param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - - auto infeed = - body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, "")); + auto token = + body_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = body_builder.AddInstruction( + HloInstruction::CreateInfeed(shape, token, "")); body_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed)); } @@ -278,8 +279,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = nested_callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); + auto token = nested_callee_builder.AddInstruction( + HloInstruction::CreateAfterAll({})); nested_callee_builder.AddInstruction( - HloInstruction::CreateOutfeed(shape, param, "")); + HloInstruction::CreateOutfeed(shape, param, token, "")); } auto nested_called_computation = module->AddEmbeddedComputation(nested_callee_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 5d8081c1ef..ff356bdd6d 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -340,10 
+340,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - infeed = (f32[4], f32[4]) infeed(), - sharding={{maximal device=1}, {maximal device=0}} - gte0 = f32[4] get-tuple-element(infeed), index=0 - gte1 = f32[4] get-tuple-element(infeed), index=1 + token = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token), + sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + gte0 = f32[4] get-tuple-element(infeed.data), index=0 + gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) copy1 = f32[4] copy(gte1) ROOT add = f32[4] add(copy0, copy1) @@ -357,8 +359,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "gte0", "infeed")); - EXPECT_TRUE(HasDomainEdge(module, "gte1", "infeed")); + EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); @@ -366,6 +367,8 @@ ENTRY entry { // HLO passes adding unexpected instructions. 
// // infeed + // | + // infeed.data (tuple element 0 of infeed) // / \ // GTE0 GTE1 // / \ @@ -374,26 +377,31 @@ ENTRY entry { // \ / // TUPLE // | - // DOMAIN HloInstruction* infeed = FindInstruction(module, "infeed"); ASSERT_NE(infeed, nullptr); - auto infeed_users = infeed->users(); - HloInstruction* new_gte0 = + HloInstruction* infeed_data = infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + + auto infeed_data_users = infeed_data->users(); + HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data, + 0)); HloInstruction* new_copy0 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte0->shape(), HloOpcode::kCopy, new_gte0)); - HloInstruction* new_gte1 = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data, + 1)); HloInstruction* new_copy1 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte1->shape(), HloOpcode::kCopy, new_gte1)); - HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction* new_tuple = infeed_data->parent()->AddInstruction( HloInstruction::CreateTuple({new_copy0, new_copy1})); - for (HloInstruction* user : infeed_users) { - TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + for (HloInstruction* user : infeed_data_users) { + TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple)); } HloDomainRemover remover(ShardingMetadata::KindName(), @@ -412,7 +420,7 @@ ENTRY entry 
{ }; for (auto& assignment : assignments) { auto device = assignment.instruction->sharding_unique_device(); - EXPECT_TRUE(device.has_value()); + ASSERT_TRUE(device.has_value()); EXPECT_EQ(*device, assignment.device); } EXPECT_TRUE(new_tuple->has_sharding()); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5c5a059e0f..c170e36c73 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - ROOT infeed = bf16[4]{0} infeed() - outfeed = () outfeed(infeed) + token = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token) + ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) } )"; auto module = CreateModuleFromHloString(hlo_string); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 33424019b9..deb7f28d84 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -902,7 +902,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { +Status HloEvaluator::HandleAfterAll(HloInstruction* token) { evaluated_[token] = Literal::CreateToken(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fc2fc9437b..2ad56080d8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -174,7 +174,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status 
HandleBroadcast(HloInstruction* broadcast) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* token) override; // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index b349f7d46f..8856723f67 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -984,7 +984,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a07dbe6256..1c8c9a8d6d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -263,12 +263,30 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; - case HloOpcode::kInfeed: - instruction = CreateInfeed(proto.shape(), proto.infeed_config()); - break; + case HloOpcode::kInfeed: { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + if (proto.operand_ids_size() == 0) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. 
+ instruction = CreateInfeed(data_shape, proto.infeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); + } + } break; case HloOpcode::kOutfeed: - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - proto.outfeed_config()); + if (proto.operand_ids_size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); + } break; case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) @@ -543,10 +561,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) { - return MakeUnique<HloMapInstruction>(shape, operands, map_computation, - static_operands); + HloComputation* map_computation) { + return MakeUnique<HloMapInstruction>(shape, operands, map_computation); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve( @@ -610,14 +626,28 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( - const Shape& shape, const string& config) { - return MakeUnique<HloInfeedInstruction>(shape, config); + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config); +} + +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( + const Shape& infeed_shape, const string& config) { + return 
MakeUnique<HloInfeedInstruction>(infeed_shape, config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand, + token_operand, outfeed_config); +} + +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique<HloOutfeedInstruction>(shape, operand, outfeed_config); + return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand, + outfeed_config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( @@ -652,11 +682,10 @@ HloInstruction::CreateCrossReplicaSum( return MakeUnique<HloReverseInstruction>(shape, operand, dimensions); } -/* static */ std::unique_ptr<HloInstruction> -HloInstruction::CreateGenerateToken( +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice<HloInstruction*> operands) { - auto instruction = WrapUnique(new HloInstruction( - HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1183,8 +1212,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; - case HloOpcode::kGenerateToken: - clone = CreateGenerateToken(new_operands); + case HloOpcode::kAfterAll: + clone = CreateAfterAll(new_operands); break; } SetupDerivedInstruction(clone.get()); @@ -1369,6 +1398,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { operand->AddUser(this); } +void 
HloInstruction::RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice<int> ascending_indices) { + if (ascending_indices.empty()) { + return; + } + int next_index = 0; + int removed_count = 0; + for (int to_remove : ascending_indices) { + while (next_index < to_remove) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_LT(to_remove, operands_.size()); + ++removed_count; + ++next_index; + } + while (next_index < operands_.size()) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_EQ(removed_count, ascending_indices.size()); + operands_.resize(operands_.size() - removed_count); +} + void HloInstruction::AddUser(HloInstruction* user) { if (!ContainsKey(user_set_, user)) { user_set_.insert(user); @@ -1447,7 +1500,7 @@ bool HloInstruction::IdenticalSlowPath( // These opcodes have complex or special behavior so just return false. case HloOpcode::kDomain: case HloOpcode::kWhile: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: return false; // Check dot dimension numbers. 
@@ -1539,6 +1592,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands()); + } return Status::OK(); } @@ -1577,6 +1634,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands()); + } } } users_.clear(); @@ -2226,8 +2287,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleGather(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); - case HloOpcode::kGenerateToken: - return visitor->HandleGenerateToken(this); + case HloOpcode::kAfterAll: + return visitor->HandleAfterAll(this); // These opcodes are not handled here. case HloOpcode::kTrace: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8f59e67123..59a383218c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -389,11 +389,10 @@ class HloInstruction { // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, - // at a given index) with the same `static_operands`. 
+ // at a given index) static std::unique_ptr<HloInstruction> CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {}); + HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. @@ -459,13 +458,29 @@ class HloInstruction { const Shape& shape, HloInstruction* operand); // Creates an infeed instruction, which reads data of the given shape from the - // Infeed interface of the device. - static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape, + // Infeed interface of the device. infeed_shape is the shape of the data + // received from the infeed *not* the shape of the infeed instruction which + // is a tuple containing the infeed_shape and the TOKEN. + static std::unique_ptr<HloInstruction> CreateInfeed( + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. + static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& infeed_shape, const string& config); - // Creates an outfeed instruction, which outputs data. + // Creates an outfeed instruction, which outputs data. outfeed_shape is the + // shape of the data being outfed *not* the shape of the outfeed instruction + // which is a TOKEN. static std::unique_ptr<HloInstruction> CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. 
+ static std::unique_ptr<HloInstruction> CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); // Creates an asynchronous send instruction with the given channel id, which @@ -665,9 +680,9 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); - // Creates a token instruction used for joining or creating token types which - // thread through side-effecting operations. - static std::unique_ptr<HloInstruction> CreateGenerateToken( + // Creates a token instruction used for joining or creating new values of + // token type which thread through side-effecting operations. + static std::unique_ptr<HloInstruction> CreateAfterAll( tensorflow::gtl::ArraySlice<HloInstruction*> operands); // Creates an instance of GatherDimensionNumbers. @@ -811,9 +826,15 @@ class HloInstruction { // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. + // + // If user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); // Replaces the specified operand with new_operand. + // + // This function does NOT remove duplicated operands even if this instruction + // is a fusion, so that the existing operand numbers do not change. Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If @@ -822,6 +843,9 @@ class HloInstruction { // // If this instruction is the root of its computation, sets the computation's // root to new_producer. + // + // If a user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. 
Status ReplaceAllUsesWith(HloInstruction* new_producer); // Performs a postorder DFS visit using this node as the root. If @@ -1440,6 +1464,10 @@ class HloInstruction { operands_.erase(operands_.begin() + index); } + // Removes a list of operands with the given indices in ascending order. + void RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice<int> ascending_indices); + void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8ee24f9d92..d8ca99dfd1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -716,10 +716,11 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { }))); auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto outfeed10 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape10, constant, "")); + HloInstruction::CreateOutfeed(shape10, constant, token, "")); auto outfeed01 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape01, constant, "")); + HloInstruction::CreateOutfeed(shape01, constant, token, "")); auto clone01 = builder.AddInstruction(outfeed01->Clone()); auto clone10 = builder.AddInstruction(outfeed10->Clone()); @@ -763,12 +764,12 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); - auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {constant}, computation_x, /*static_operands=*/{})); - auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_1_x}, computation_x, 
/*static_operands=*/{})); - auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); + auto map_1_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); + auto map_2_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x)); + auto map_3_y = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y)); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( @@ -1170,6 +1171,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } +TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + const Shape s = ShapeUtil::MakeShape(F32, {10, 10}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok()); + + EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y)); + EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1); +} + TEST_F(HloInstructionTest, FusionEquality) { auto module = 
CreateNewModule(); HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 803fde73a5..e2f43f5810 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -553,10 +554,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( HloMapInstruction::HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; for (auto operand : operands) { AppendOperand(operand); } @@ -1210,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation); } +Status HloFusionInstruction::DeduplicateFusionOperands() { + tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices; + std::vector<int> operands_to_remove; + for (int i = 0; i < operand_count(); ++i) { + auto emplace_result = operand_indices.emplace(operand(i), i); + if (!emplace_result.second) { + TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith( + fused_parameter(emplace_result.first->second))); + operands_to_remove.push_back(i); + } + } + if (operands_to_remove.empty()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + fused_instructions_computation()->RemoveUnusedParameters()); + RemoveOperandsAtAscendingIndices(operands_to_remove); + return Status::OK(); +} + HloRngInstruction::HloRngInstruction( const Shape& 
shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice<HloInstruction*> parameters) @@ -1365,9 +1384,22 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], exponent_bits(), mantissa_bits()); } -HloInfeedInstruction::HloInfeedInstruction(const Shape& shape, +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, const string& config) - : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {} + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) { + AppendOperand(token_operand); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) {} HloInstructionProto HloInfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); @@ -1395,19 +1427,37 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); - return MakeUnique<HloInfeedInstruction>(shape, infeed_config()); + if (new_operands.empty()) { + return MakeUnique<HloInfeedInstruction>(infeed_shape(), infeed_config()); + } else { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0], + infeed_config()); + } } HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) 
{ + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); + AppendOperand(token_operand); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) - : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()), - outfeed_shape_(shape), + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); } @@ -1438,9 +1488,14 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0], - outfeed_config()); + if (new_operands.size() == 1) { + return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0], + outfeed_config()); + } else { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0], + new_operands[1], outfeed_config()); + } } HloConvolutionInstruction::HloConvolutionInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 1a2e4ae0a5..ec8a42bd3b 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -407,8 +407,7 @@ class 
HloMapInstruction : public HloInstruction { public: explicit HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {}); + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector<int64>& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -636,6 +635,9 @@ class HloFusionInstruction : public HloInstruction { void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + // If multiple operands are the same instruction, keeps only one of them. + Status DeduplicateFusionOperands(); + private: // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the @@ -785,12 +787,25 @@ class HloReducePrecisionInstruction : public HloInstruction { class HloInfeedInstruction : public HloInstruction { public: - explicit HloInfeedInstruction(const Shape& shape, const string& config); + explicit HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config); + // TODO(b/80000000): Remove this constructor when all uses of infeed are + // converted to take tokens. + explicit HloInfeedInstruction(const Shape& infeed_shape, + const string& config); // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. string infeed_config() const { return infeed_config_; } void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns the shape of the data received by the infeed. This is not the same + // as the shape of the infeed instruction which produces a tuple containing + // the infeed data shape and a TOKEN. 
+ const Shape& infeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return ShapeUtil::GetSubshape(shape(), {0}); + } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -813,11 +828,19 @@ class HloInfeedInstruction : public HloInstruction { class HloOutfeedInstruction : public HloInstruction { public: - explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand, + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // TODO(b/80000000): Remove this constructor when all uses of outfeed are + // converted to take tokens. + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. 
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 8a31a8e617..b57c940238 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -187,7 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); -HLO_MATCHER(GenerateToken); +HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -196,6 +196,7 @@ HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); +HLO_MATCHER(Xor); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 7083321276..05e47a698f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -81,7 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 774345124b..6f3f83f63a 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,7 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 605c6ae741..57d17064c1 100644 --- 
a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -617,12 +617,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } - case HloOpcode::kGenerateToken: { + case HloOpcode::kAfterAll: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateGenerateToken(operands)); + instruction = + builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); break; } case HloOpcode::kTuple: { @@ -978,23 +978,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional<string> config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateInfeed(shape, config ? *config : "")); + // We need to know the infeed data shape to construct the infeed + // instruction. This is the zero-th element of the tuple-shaped output of + // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail + // if the shape is not a non-empty tuple, so add guard so an error message + // can be emitted instead of a check fail + if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + return Error(lexer_.GetLoc(), + "infeed must have a non-empty tuple shape"); + } + + if (operands.empty()) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), config ? 
*config : "")); + } else if (operands.size() == 1) { + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "infeed must have exactly zero or one operands"); + } break; } case HloOpcode::kOutfeed: { optional<string> config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); + if (operands.size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + operands[0]->shape(), operands[0], config ? *config : "")); + } else if (operands.size() == 2) { + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? 
*config : "")); + } else { + return Error(lexer_.GetLoc(), + "outfeed must have exactly one or two operands"); + } break; } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d481e07f60..da1a34ae3c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -795,10 +795,14 @@ ENTRY ReduceR3ToR2.v3 { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - infeed = (u32[3]{0}, pred[]) infeed() - outfeed = () outfeed(infeed) - ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() - outfeed.1 = () outfeed(infeed.1) + token = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) } )" @@ -1418,5 +1422,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, NontupleInfeed) { + const string original = R"(HloModule nontuple_infeed: +ENTRY nontuple_infeed { + token = token[] after-all() + ROOT infeed = pred[] infeed(token) +})"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "infeed must have a non-empty tuple shape"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1d6cd4cb23..fb39c6f085 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include <set> +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -106,22 +108,50 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); + // Infeed has an optional single token operand. + // TODO(b/80000000): Update when token is not optional. + if (infeed->operand_count() == 1 && + !ShapeUtil::Equal(infeed->operand(0)->shape(), + ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected infeed operand to be token-shaped, actual shape is %s:\n%s", + ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(), + infeed->ToString().c_str()); + } + + // The output of infeed is a tuple containing the data value and a token. + return CheckShape(infeed, + ShapeUtil::MakeTupleShape( + {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()})); +} + +Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction); + // Outfeed has an optional token operand (operand 1). + // TODO(b/80000000): Update when token is not optional. 
+ if (outfeed->operand_count() == 2 && + !ShapeUtil::Equal(outfeed->operand(1)->shape(), + ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s", + ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(), + outfeed->ToString().c_str()); + } -Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the - // host. The shape of the instruction itself is always nil because the outfeed - // produces no HLO value in the graph. + // host. The shape of the instruction itself is always a token. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed to have shape compatible with operand's shape %s, " + "Expected outfeed shape to be compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } - return CheckShape(outfeed, ShapeUtil::MakeNil()); + return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { @@ -426,13 +456,12 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } -Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { +Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector<const Shape*> operand_shapes; for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, - ShapeInference::InferGenerateTokenShape(operand_shapes)); + return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -786,8 +815,7 @@ Status 
HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); - if (!ShapeUtil::IsScalar(operand_shape) && - !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." "Found non-compatible shapes for instruction %s.\n" @@ -815,9 +843,10 @@ bool ShapeContainsToken(const Shape& shape) { } // Verifies that all types entering and exiting the entry computation are -// legal. For example, TOKEN types have no Literal representation and cannot be -// on the interface of the entry computation (parameters and root instruction). +// legal. Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { HloInstruction* param = module.entry_computation()->parameter_instruction(i); @@ -827,14 +856,6 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { ShapeUtil::HumanString(param->shape()).c_str()); } } - if (ShapeContainsToken( - module.entry_computation()->root_instruction()->shape())) { - return InternalError( - "Entry root is or contains a token shape: %s", - ShapeUtil::HumanString( - module.entry_computation()->root_instruction()->shape()) - .c_str()); - } return Status::OK(); } @@ -881,7 +902,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->IsElementwise()) { + } else if (instruction->opcode() != + HloOpcode::kRng /* Rng operands are always scalar. 
*/ + && instruction->IsElementwise()) { TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7283b3e7dc..da6b5d2222 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,7 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 9ac8635767..088cc26226 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -97,7 +97,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 62599b376a..67e2cf6c77 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -770,9 +770,13 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. 
- auto infeed = - false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); - false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + auto token = + false_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = false_builder.AddInstruction( + HloInstruction::CreateInfeed(xshape, token, "")); + auto infeed_data = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xshape, infeed, 0)); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = module->AddEmbeddedComputation(false_builder.Build()); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 3a6a7c25f4..f6e7578a89 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); - // Update count to at least the numeric suffix value to avoid future - // colisions with this name. - generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); - if (*count == 0) { - *count = 1; + + SequentialIdGenerator& id_generator = generated_names_[root]; + numeric_suffix = id_generator.RegisterId(numeric_suffix); + if (numeric_suffix == 0) { return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) : root; - } else { - tensorflow::strings::StrAppend(&root, separator_, *count); - // Increment lookup under old 'root' name. 
- (*count)++; - return root; } + tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + return root; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4139c2700b..4423d61069 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -17,10 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ #include <string> -#include <unordered_map> #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -44,13 +45,40 @@ class NameUniquer { static string GetSanitizedName(const string& name); private: + // Used to track and generate new identifiers for the same instruction name + // root. + class SequentialIdGenerator { + public: + SequentialIdGenerator() = default; + + // Tries to register id as used identifier. If id is not already used, the + // id itself will be returned. Otherwise a new one will be generated, and + // returned. + int64 RegisterId(int64 id) { + if (used_.insert(id).second) { + return id; + } + while (!used_.insert(next_).second) { + ++next_; + } + return next_++; + } + + private: + // The next identifier to be tried. + int64 next_ = 0; + + // Set of all the identifiers which has been used. + tensorflow::gtl::FlatSet<int64> used_; + }; + // The string to use to separate the prefix of the name from the uniquing // integer value. string separator_; - // Map from name prefix to the number of names generated using that prefix - // so far. - std::unordered_map<string, int64> generated_names_; + // Map from name prefix to the generator data structure which tracks used + // identifiers and generates new ones. 
+ tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 2ec255558c..3e2592c6ac 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); - EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); - EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); - EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000")); + EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000")); + EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1")); } TEST_F(NameUniquerTest, PrefixHasSuffix) { @@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); - EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. 
- EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); - EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); - EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); + EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000")); + EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. EXPECT_EQ("_10", uniquer.GetUniqueName( @@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } +TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); + EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1")); + EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index bbc95f8630..096bbde922 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -329,7 +329,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr<Shape> ShapeInference::InferGenerateTokenShape( +/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index eef6e62fc8..ad34a2aa18 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -216,11 +216,11 @@ class ShapeInference { static StatusOr<Shape> InferConcatOpShape( 
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension); - // Infers the shape produced by a kGenerateToken operation. Trivially this - // shape is always a TOKEN shape. However, ShapeInference serves two purposes: - // inferring shapes and checking operand shapes. This method verifies that the - // operand shapes are all TOKENs. - static StatusOr<Shape> InferGenerateTokenShape( + // Infers the shape produced by a kAfterAll. Trivially this shape is always a + // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes + // and checking operand shapes. This method verifies that the operand shapes + // are all TOKENs. + static StatusOr<Shape> InferAfterAllShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes); // Helper that validates the given operand shape can be converted to the diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8831c513ee..23519e445e 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = 
builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); - + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, WhileLoopInvariantCodeMotion{}.Run(&module())); - EXPECT_FALSE(simplified_loop); + ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Outfeed())); @@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. 
auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); HloInstruction* bitcast_inst = builder.AddInstruction( HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); 
module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 619e87caa5..0536c99b67 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -208,8 +208,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); + while_body->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index d79d329721..2ccb919acf 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -179,7 +179,9 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - ROOT condition = pred[] infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) + ROOT condition = pred[] get-tuple-element(infeed), index=0 } ENTRY main { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 98c3095499..e827ec5a22 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -94,8 +95,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (!ShapeUtil::SameElementType(lhs, rhs)) { +bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, + bool ignore_fp_precision) { + if ((ignore_fp_precision && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { VLOG(3) << "CompareShapes: lhs element type != rhs element type"; return false; } @@ -103,7 +107,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (ShapeUtil::IsTuple(lhs)) { return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially @@ -170,7 +175,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/false); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -179,6 +185,18 @@ StatusOr<Shape> 
MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/true); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ int64 ShapeUtil::Rank(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)) << "Non-arrays do not have a rank, shape: " << shape; @@ -665,7 +683,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false); + return CompareShapes(lhs, rhs, /*compare_layouts=*/false, + /*ignore_fp_precision=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -867,6 +886,50 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { } } + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); + return Status::OK(); +} + +/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { + VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); + auto invalid_argument = + InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); + if (!IsArray(shape)) { + return Status::OK(); + } + int64 shape_size; + if (LayoutUtil::IsSparseArray(shape)) { + shape_size = LayoutUtil::MaxSparseElements(shape.layout()); + shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); + if (shape_size < 0) { + return invalid_argument; + } + shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); + if (shape_size < 0) { + return invalid_argument; + } + } + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. 
+ if (shape.dimensions().empty()) { + return Status::OK(); + } + shape_size = 1; + for (int64 dim : shape.dimensions()) { + shape_size = MultiplyWithoutOverflow(shape_size, dim); + if (shape_size < 0) { + return invalid_argument; + } + } + shape_size = MultiplyWithoutOverflow( + shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + if (shape_size < 0) { + return invalid_argument; + } + + VLOG(3) << "Shape size is valid: " << shape_size; return Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 02e4f41505..5ae04451d3 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -280,6 +280,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. + static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) static int64 Rank(const Shape& shape); @@ -699,6 +702,10 @@ class ShapeUtil { static size_t Hash(const Shape& shape); private: + // Validates the shape size is sane. This makes sure it's safe to do + // calculations in int64 without overflowing. + static Status ValidateShapeSize(const Shape& shape); + // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. 
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 606f7492ce..b6f30af381 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, EqualIgnoringFpPrecision) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 8ac771ae5a..0aaa990503 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -282,7 +282,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { std::unique_ptr<GlobalData> rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto sub = b.Sub(lhs_param, rhs_param); + b.Sub(lhs_param, rhs_param); std::vector<int64> expected(lhs.size()); for (int64 i = 0; i < lhs.size(); 
++i) { @@ -2456,7 +2456,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // comparison. auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); - auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + builder.Tuple({cmp_dim_0, cmp_dim_1}); auto expected = Literal::MakeTuple( {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(), diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index f3dac75a44..3489514fe8 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -252,7 +252,7 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } -XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index ca337e7884..9d4f723ed6 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -92,8 +92,8 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { auto offset = builder.ConstantR1<bfloat16>( {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + builder.BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, + kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4<bfloat16>( diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 3a0f51fc66..1a7f188346 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ 
b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -262,7 +262,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -516,7 +516,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 660ff0cad5..7c73e80d69 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -40,7 +40,7 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); auto param0_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -79,7 +79,7 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); - auto add = builder.Mul(p0, p1); + builder.Mul(p0, p1); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc 
b/tensorflow/compiler/xla/tests/constants_test.cc index 916ffadbc7..1b929d7d2f 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -109,7 +109,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - auto constant = builder.ConstantLiteral( + builder.ConstantLiteral( *Literal::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2))); ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {}); @@ -125,8 +125,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = - builder.ConstantLiteral(*Literal::CreateR3FromArray3D<float>(array3d)); + builder.ConstantLiteral(*Literal::CreateR3FromArray3D<float>(array3d)); ComputeAndCompareR3<float>(&builder, array3d, {}); } diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 3a885b4389..ba5ba3a82f 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -478,7 +478,7 @@ XLA_TEST_F(ConvertTest, ConvertBF16F32) { xla::XlaOp all_bfloats_bf16 = builder.ConstantR1<bfloat16>(all_bfloats); xla::XlaOp all_bfloats_f32 = builder.ConvertElementType(all_bfloats_bf16, F32); - xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32); + builder.BitcastConvertType(all_bfloats_f32, U32); ComputeAndCompareR1<uint32>(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 2b3390ca98..b20499f252 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -248,7 +248,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto empty = Literal::CreateFromShape(in_shape); XlaBuilder builder(TestName()); - auto param0 = builder.Parameter(0, in_shape, "input"); + builder.Parameter(0, in_shape, "input"); auto input_data = 
client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0fd846cef8..6a2c581aec 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -89,7 +89,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { auto lhs = builder.ConstantR1<T>({}); auto rhs = builder.ConstantR1<T>({}); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {}, this->error_spec_); @@ -104,7 +104,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D<T>({{3.0f, 4.0f}}); auto rhs = builder.ConstantFromArray<T>({3.0f, 4.0f}); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {}, this->error_spec_); @@ -115,7 +115,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR1<T>({static_cast<T>(2.0f)}); auto rhs = builder.ConstantR1<T>({static_cast<T>(3.0f)}); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {}, this->error_spec_); @@ -126,7 +126,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantFromArray<T>({1.0f, 2.5f, 42.0f}); auto rhs = builder.ConstantFromArray<T>({11.0f, -1.0f, 0.5f}); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {}, this->error_spec_); @@ -141,7 +141,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) 
{ XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2)); auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0)); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {}, this->error_spec_); @@ -153,7 +153,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2)); auto rhs = builder.ConstantR2FromArray2D<T>( {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {}, this->error_spec_); @@ -165,7 +165,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { auto lhs = builder.ConstantR2FromArray2D<T>( {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0)); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {}, this->error_spec_); @@ -176,7 +176,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0)); auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2)); - auto result = builder.Dot(lhs, rhs); + builder.Dot(lhs, rhs); this->template ComputeAndCompareR2<T>( &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_); @@ -190,7 +190,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { auto param1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1"); auto exp0 = builder.Exp(param0); - auto result = builder.Dot(exp0, param1); + builder.Dot(exp0, param1); auto lhs_handle = this->client_ @@ -231,7 +231,7 @@ class SquareMatrixDot : public DotOperationTest { .ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType<T>(); - 
auto result = builder.Dot( + builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); @@ -492,7 +492,7 @@ class NonsquareMatrixDot : public DotOperationTest { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType<T>(); - auto result = builder.Dot( + builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); @@ -524,7 +524,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType<complex64>(); - auto result = builder.Dot( + builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); @@ -626,7 +626,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); - auto out = builder.DotGeneral(x, y, dnums); + builder.DotGeneral(x, y, dnums); auto x_data = this->client_ @@ -690,7 +690,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { if (transpose_rhs) { rhs_arg = builder.Transpose(rhs_arg, {1, 0}); } - auto result = builder.Dot(lhs_arg, rhs_arg); + builder.Dot(lhs_arg, rhs_arg); Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " @@ -720,8 +720,8 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, "rhs_arg_1"); auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2"); - auto result = builder.Dot( - lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + builder.Dot(lhs_constant, + builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); std::unique_ptr<Array2D<T>> arg_0_value_array( new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -768,8 +768,8 @@ 
XLA_TYPED_TEST(DotOperationTest_F16F32F64, "lhs_arg_1"); auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2"); - auto result = builder.Dot( - builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + builder.Dot(builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), + rhs_constant); std::unique_ptr<Array2D<T>> arg_0_value_array( new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -820,7 +820,7 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D<float> expected({{96.0, 105.0, 114.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -848,7 +848,7 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D<float> expected({{105.0}, {105.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -856,8 +856,8 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { // TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -879,7 +879,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D<float> expected({{105.0, 105.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -887,8 +887,8 @@ XLA_TEST_F(DotOperationTest, // TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -910,7 +910,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D<float> expected({{96.0}, {105.0}, {114.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -918,8 +918,8 @@ XLA_TEST_F(DotOperationTest, // TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSRows)))) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -946,7 +946,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D<float> expected({{126.0, 129.0, 132.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -954,8 +954,8 @@ XLA_TEST_F(DotOperationTest, // TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSRows)))) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -982,7 +982,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D<float> expected({{129.0}, {129.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -990,8 +990,8 @@ XLA_TEST_F(DotOperationTest, // TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSCols)))) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1010,7 +1010,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D<float> expected({{56.0, 168.0, 91.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); @@ -1018,8 +1018,8 @@ XLA_TEST_F(DotOperationTest, // TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSCols)))) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1038,7 +1038,7 @@ XLA_TEST_F(DotOperationTest, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D<float> expected({{168.0}, {168.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index e6f79b5ac5..45a5cdc896 100644 --- 
a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -557,8 +557,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { *ExecuteAndTransfer(std::move(hlo_module), {}))); } -// TODO(b/64070202): Investigate failure. -XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { +XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -800,7 +799,7 @@ void BM_ParallelFusion(int num_iters) { auto param2 = builder.Parameter(2, shape2, "param2"); auto x = builder.Mul(param0, param1); - auto y = builder.Add(x, param2); + builder.Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 5a70c2a9ae..77f9c33ee1 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase { XLA_TEST_F(LocalClientExecuteTest, Constant) { XlaBuilder builder(TestName()); - auto y = builder.ConstantR0<float>(123.0f); + builder.ConstantR0<float>(123.0f); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); @@ -701,7 +701,7 @@ XLA_TEST_F(LocalClientExecuteTest, TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); - auto y = builder.ConstantR0<float>(123.0f); + builder.ConstantR0<float>(123.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -841,6 +841,31 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { Literal::CreateR0<int64>(123456789000LL).get()})); } +XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { + XlaBuilder builder(TestName()); + const Shape 
shape = ShapeUtil::MakeShape(F32, {3}); + auto in = builder.Infeed(shape); + auto constant = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f}); + builder.Add(in, constant); + + std::unique_ptr<Literal> result; + std::unique_ptr<tensorflow::Thread> thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + result = ShapedBufferToLiteral(ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), /*arguments=*/{})); + })); + + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1<float>({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); + + // Join the thread. + thread.reset(); + + LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result); +} + // TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. // 2017-10-18. XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 27fd36e06a..c1f1c45c8c 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -89,7 +89,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { {1.0f, 0.0f}, // row 0 {-1.0f, 0.5f}, // row 1 }); - auto map = builder.Map({data}, add_half, {0, 1}); + builder.Map({data}, add_half, {0, 1}); std::unique_ptr<Literal> expected = Literal::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0 @@ -108,7 +108,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { {5.0f, 6.0f}, // row 0 {1.0f, -8.0f}, // row 1 }); - auto max = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); std::unique_ptr<Literal> expected = Literal::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0 @@ -139,7 +139,7 @@ class TestLinspaceMaxParametric tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); auto lhs = builder.ConstantR2FromArray2D<T>(*alhs); auto rhs = builder.ConstantR2FromArray2D<T>(*arhs); - 
auto max = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); Array2D<T> expected(rows, cols); for (int row = 0; row < rows; ++row) { diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index a42a19af15..6597748c8d 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -454,7 +454,8 @@ XLA_TEST_F(MultiOutputFusionTest, r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add mul = f32[2,2,2]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={} + mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1) ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) tuple(r1, mul, mul2) } diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 838f1b4e2f..3c3c865673 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); ComputeAndCompareR0<float>(&builder, 3.14159f, {param0_data.get()}, ErrorSpec(0.0001f)); @@ -58,7 +58,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -71,7 +71,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { std::unique_ptr<GlobalData> param0_data = 
client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); ComputeAndCompareR1<float>(&builder, {3.14f, -100.25f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -84,7 +84,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter( + builder.Parameter( 0, ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}), "param0"); ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); @@ -97,7 +97,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {param0_data.get()}, ErrorSpec(0.01f)); @@ -110,7 +110,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); Array2D<float> expected_array( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); @@ -142,7 +142,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { // parameters to test that the parameters are not swapped. 
// // {11, 22} * {10, 20} = {110, 440} - auto prod = builder.Mul(sum, param1); + builder.Mul(sum, param1); ComputeAndCompareR1<float>(&builder, {110, 440}, {param0_data.get(), param1_data.get()}, @@ -157,7 +157,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { client_->TransferToServer(*literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); + builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); ASSERT_NE(computation_status.status(), Status::OK()); @@ -169,12 +169,12 @@ XLA_TEST_F(ParamsTest, UnusedParameter) { std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + builder.Parameter(0, literal0->shape(), "param0"); std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20}); std::unique_ptr<GlobalData> param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + builder.Parameter(1, literal1->shape(), "param1"); ComputeAndCompareR1<float>(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -478,7 +478,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::unique_ptr<Literal> literal = Literal::CreateR2<float>({ - {1, 3}, {2, 4}, + {1, 3}, + {2, 4}, }); const Shape original = literal->shape(); { diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 77159efb26..f405bb3d49 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -36,20 +36,20 @@ class PredTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaOp lhs_op = builder.ConstantR0<bool>(lhs); XlaOp rhs_op = 
builder.ConstantR0<bool>(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0<bool>(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0<bool>(true); + builder.ConstantR0<bool>(true); ComputeAndCompareR0<bool>(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0<bool>(false); + builder.ConstantR0<bool>(false); ComputeAndCompareR0<bool>(&builder, false, {}); } @@ -79,14 +79,13 @@ TEST_F(PredTest, ConstantR0PredCompareGt) { TEST_F(PredTest, ConstantR1Pred) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1<bool>({true, false, false, true}); + builder.ConstantR1<bool>({true, false, false, true}); ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2<bool>({{false, true, true}, {true, false, false}}); + builder.ConstantR2<bool>({{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, { 100 } @@ -97,21 +96,21 @@ TEST_F(PredTest, ConstantR2Pred) { TEST_F(PredTest, AnyR1True) { XlaBuilder builder(TestName()); auto a = builder.ConstantR1<bool>({true, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + Any(a); ComputeAndCompareR0<bool>(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { XlaBuilder builder(TestName()); auto a = builder.ConstantR1<bool>({false, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + Any(a); ComputeAndCompareR0<bool>(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { XlaBuilder builder(TestName()); auto a = builder.ConstantR1<bool>({}); - TF_ASSERT_OK(Any(a, &builder).status()); + Any(a); ComputeAndCompareR0<bool>(&builder, false, {}); } @@ -122,7 +121,7 @@ TEST_F(PredTest, AnyR2True) { {false, false, false}, {false, false, true}, }); - TF_ASSERT_OK(Any(a, 
&builder).status()); + Any(a); ComputeAndCompareR0<bool>(&builder, true, {}); } @@ -133,7 +132,7 @@ TEST_F(PredTest, AnyR2False) { {false, false, false}, {false, false, false}, }); - TF_ASSERT_OK(Any(a, &builder).status()); + Any(a); ComputeAndCompareR0<bool>(&builder, false, {}); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1a2de6937c..ba58feea8e 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) { XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. - auto rng_uniform = builder.RngUniform(builder.ConstantR0<int32>(0), - builder.ConstantR0<int32>(1000 * 1000), - ShapeUtil::MakeShape(S32, {})); + builder.RngUniform(builder.ConstantR0<int32>(0), + builder.ConstantR0<int32>(1000 * 1000), + ShapeUtil::MakeShape(S32, {})); SetSeed(0); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d671d40456..579be77b24 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -829,8 +829,8 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_activations = builder.Parameter(0, input_literal->shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), - add, GetParam().reduce_dims); + builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, + GetParam().reduce_dims); auto expected = ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims, @@ -878,7 +878,7 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { std::unique_ptr<GlobalData> b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b = builder.Parameter(0, 
b_literal->shape(), "b"); - auto max = builder.Reduce(b, a2, max_f32, {0}); + builder.Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index da1b588ec4..3e5087922c 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -48,7 +48,7 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}}); auto c = builder.Reshape(a, {6}); auto d = builder.Reshape(b, {6}); - auto e = builder.Mul(c, d); + builder.Mul(c, d); ComputeAndCompareR1<int32>(&builder, {34, 57, 115, 203, 341, 481}, {}); } diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index a4580cd71d..fccc497550 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -125,10 +125,7 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { +XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D<float> input_array(0, 3); auto input_literal = Literal::CreateR2FromArray2D(input_array); @@ -141,10 +138,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 -// with an incorrect result rank. 
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr<Literal> param0_literal = @@ -158,10 +152,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { +XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D<float> input_array(3, 0); auto input_literal = Literal::CreateR2FromArray2D(input_array); @@ -200,12 +191,8 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Splits an empty vector into an empty matrix. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { +XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1<float>({}); XlaOp parameter; @@ -234,12 +221,8 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { +XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2)); XlaOp parameter; @@ -286,12 +269,8 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. 
Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 0x4 array with XlaBuilder::Transpose. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { +XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4)); XlaOp parameter; @@ -319,13 +298,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0)); XlaOp parameter; @@ -338,10 +313,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0)); XlaOp parameter; @@ -372,11 +344,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. 
-// -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6)); XlaOp parameter; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 308d3fc78a..323635b0e6 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -51,7 +51,7 @@ class ScalarComputationsTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs); XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0<bool>(&builder, expected, {}); } @@ -62,7 +62,7 @@ class ScalarComputationsTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs); XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0<NativeT>(&builder, expected, {}); } }; diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 72707f2244..6d6c393655 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -38,7 +38,7 @@ TEST_F(SelectTest, SelectScalarF32True) { auto pred = builder.ConstantR0<bool>(true); auto on_true = builder.ConstantR0<float>(123.0f); auto on_false = builder.ConstantR0<float>(42.0f); - auto result = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_); } @@ -48,7 +48,7 @@ TEST_F(SelectTest, SelectScalarS32True) { auto pred = builder.ConstantR0<bool>(true); auto on_true 
= builder.ConstantR0<int32>(-42); auto on_false = builder.ConstantR0<int32>(42); - auto result = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR0<int32>(&builder, -42, {}); } @@ -58,7 +58,7 @@ TEST_F(SelectTest, SelectScalarF32False) { auto pred = builder.ConstantR0<bool>(false); auto on_true = builder.ConstantR0<float>(123.0f); auto on_false = builder.ConstantR0<float>(42.0f); - auto result = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_); } @@ -68,7 +68,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { auto pred = builder.ConstantR1<bool>({}); auto on_true = builder.ConstantR1<float>({}); auto on_false = builder.ConstantR1<float>({}); - auto select = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); } @@ -78,7 +78,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { auto pred = builder.ConstantR1<bool>({false, true, false, true, false}); auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -93,7 +93,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { auto cmp = builder.Eq(v1, v2); auto on_true = builder.ConstantR1<float>({}); auto on_false = builder.ConstantR1<float>({}); - auto select = builder.Select(cmp, on_true, on_false); + builder.Select(cmp, on_true, on_false); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); } @@ -107,7 +107,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { auto cmp = builder.Eq(v1, v2); auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); 
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + builder.Select(cmp, on_true, on_false); ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -121,7 +121,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { auto cmp = builder.Gt(v1, v2); auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + builder.Select(cmp, on_true, on_false); ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, error_spec_); @@ -141,7 +141,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + builder.Select(cmp, v1, v2); ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -182,7 +182,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + builder.Select(cmp, v1, v2); ComputeAndCompareR1<float>(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -199,7 +199,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + builder.Select(cmp, on_true, on_false); ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, error_spec_); @@ -216,7 +216,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = 
builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + builder.Select(cmp, on_true, on_false); ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, error_spec_); @@ -228,7 +228,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { auto pred = builder.ConstantR0<bool>(which); auto on_true = builder.ConstantR1<float>({}); auto on_false = builder.ConstantR1<float>({}); - auto select = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); } @@ -239,7 +239,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { auto pred = builder.ConstantR0<bool>(true); auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f}); auto on_false = builder.ConstantR1<float>({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_); } @@ -249,7 +249,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { auto pred = builder.ConstantR0<bool>(false); auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f}); auto on_false = builder.ConstantR1<float>({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + builder.Select(pred, on_true, on_false); ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 8541698576..e9008fa48a 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -31,27 +31,29 @@ class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { std::unique_ptr<HloModule> module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - 
builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); + builder.AddInstruction(HloInstruction::CreateAfterAll({})); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { std::unique_ptr<HloModule> module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - builder.AddInstruction( - HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); + HloInstruction::CreateAfterAll({token0, token0, token1, token2})); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -89,24 +91,12 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { - std::unique_ptr<HloModule> module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - 
builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - module->AddEntryComputation(builder.Build()); - - Status status = HloVerifier().Run(module.get()).status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr("Entry root is or contains a token shape")); -} - XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { std::unique_ptr<HloModule> module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction(HloInstruction::CreateAfterAll({param})); builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<int32>(123))); module->AddEntryComputation(builder.Build()); @@ -120,7 +110,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a - // GenerateToken instruction in the while body. + // AfterAll instruction in the while body. 
string module_string = R"( HloModule TokenInWhileLoop @@ -130,8 +120,8 @@ HloModule TokenInWhileLoop %constant.1 = s32[] constant(1) %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %generate-token = token[] generate-token(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) } %Cond (param: (s32[], token[])) -> pred[] { @@ -143,7 +133,7 @@ HloModule TokenInWhileLoop ENTRY %TokenInWhileLoop () -> s32[] { %zero = s32[] constant(0) - %init_token = token[] generate-token() + %init_token = token[] after-all() %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 @@ -172,13 +162,13 @@ HloModule TokenInConditional %False (param.2: s32[]) -> (s32[], token[]) { %param.2 = s32[] parameter(0) - %new_token = token[] generate-token() + %new_token = token[] after-all() ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token) } ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { %param.3 = pred[] parameter(0) - %init_token = token[] generate-token() + %init_token = token[] after-all() %seven = s32[] constant(7) %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0 diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 85799d4cfb..86babb58c9 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -256,6 
+256,18 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } +XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { + // "Copy" a token from the device. The token has no physical representation so + // no copying is actually performed, but it shouldn't fail. + // TODO(b/110532604): Add transferring the token to device when this is + // supported. + auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Literal> result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result)); +} + XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; std::unique_ptr<Literal> literal1 = Literal::MakeTuple( diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe1e3da7ec..db85344ed6 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -39,7 +39,7 @@ class TransposeTest : public ClientLibraryTestBase { XLA_TEST_F(TransposeTest, Transpose0x0) { XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + builder.Transpose(lhs, {1, 0}); ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_); } @@ -47,7 +47,7 @@ XLA_TEST_F(TransposeTest, Transpose0x0) { XLA_TEST_F(TransposeTest, Transpose0x42) { XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42)); - auto result = builder.Transpose(lhs, {1, 0}); + builder.Transpose(lhs, {1, 0}); ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_); } @@ -55,7 +55,7 @@ XLA_TEST_F(TransposeTest, Transpose0x42) { XLA_TEST_F(TransposeTest, Transpose7x0) { XlaBuilder 
builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + builder.Transpose(lhs, {1, 0}); ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_); } @@ -65,7 +65,7 @@ TEST_F(TransposeTest, Transpose2x2) { auto lhs = builder.ConstantR2<float>({ {1.0, 2.0}, {3.0, 4.0}, }); - auto result = builder.Transpose(lhs, {1, 0}); + builder.Transpose(lhs, {1, 0}); Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); @@ -75,7 +75,7 @@ TEST_F(TransposeTest, Transpose2x2) { XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3)); - auto result = builder.Transpose(operand, {1, 2, 0}); + builder.Transpose(operand, {1, 2, 0}); ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {}); } @@ -83,7 +83,7 @@ XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {1, 2, 0}); + builder.Transpose(operand, {1, 2, 0}); Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); @@ -93,7 +93,7 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {2, 1, 0}); + builder.Transpose(operand, {2, 1, 0}); Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); @@ -103,7 +103,7 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {0, 1, 2}); + builder.Transpose(operand, {0, 1, 2}); Array3D<int32> 
expected({{{1, 2, 3}, {4, 5, 6}}}); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index c3abe22797..dbbe1b49e4 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -39,7 +39,7 @@ class UnaryOpTest : public ClientLibraryTestBase { void AbsSize0TestHelper() { XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<T>({}); - auto abs = builder.Abs(arg); + builder.Abs(arg); if (primitive_util::NativeToPrimitiveType<T>() == C64) { ComputeAndCompareR1<float>(&builder, {}, {}); @@ -52,7 +52,7 @@ class UnaryOpTest : public ClientLibraryTestBase { void AbsTestHelper() { XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()}); - auto abs = builder.Abs(arg); + builder.Abs(arg); ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {}); } @@ -62,7 +62,7 @@ class UnaryOpTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<T>( {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()}); - auto sign = builder.Sign(arg); + builder.Sign(arg); ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); } @@ -98,7 +98,7 @@ void UnaryOpTest::AbsTestHelper<complex64>() { {-0.3f, 0.4f}, {0, inf<float>()}, {-inf<float>(), 0}}); - auto abs = builder.Abs(arg); + builder.Abs(arg); std::unique_ptr<Literal> expected = Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()}); @@ -110,7 +110,7 @@ void UnaryOpTest::SignTestHelper<complex64>() { XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<complex64>( {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}}); - auto sign = builder.Sign(arg); + builder.Sign(arg); std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); @@ -196,7 +196,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { 
XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<unsigned int>( {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); - auto abs = builder.Abs(arg); + builder.Abs(arg); ComputeAndCompareR1<unsigned int>( &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {}); @@ -206,7 +206,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { XlaBuilder builder(TestName()); auto arg = builder.ConstantR1<unsigned int>( {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); - auto sign = builder.Sign(arg); + builder.Sign(arg); ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {}); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 82d301983f..9e76177483 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -58,9 +58,8 @@ TEST_F(VecOpsReduceTest, AddReduceR1F32) { auto x = builder_.ConstantR1<float>( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0<float>(&builder_, -4.2f, {}, errspec_); } @@ -72,9 +71,8 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { std::iota(input.begin(), input.end(), 100.0f); auto x = builder_.ConstantR1<float>(input); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); float expected = std::accumulate(input.begin(), input.end(), 0.0f); ComputeAndCompareR0<float>(&builder_, expected, {}, errspec_); @@ -85,9 +83,8 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) { auto x = builder_.ConstantR1<float>( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, 
-2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0<float>(&builder_, 2.6f, {}, errspec_); } @@ -97,9 +94,8 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { auto x = builder_.ConstantR1<float>( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0<float>(&builder_, 4.0f, {}, errspec_); } @@ -114,9 +110,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { // ------ dim 1 ---------- // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); ComputeAndCompareR1<float>(&builder_, {6.0, 15.0}, {}, errspec_); } @@ -129,9 +124,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR1<float>(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); } @@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{2}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + 
/*dimensions_to_reduce=*/{2}); Array2D<float> expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); @@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); Array2D<float> expected_array( {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); @@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); Array2D<float> expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); @@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1, 2}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1, 2}); ComputeAndCompareR1<float>(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); } @@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - 
/*dimensions_to_reduce=*/{0, 2}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 2}); ComputeAndCompareR1<float>(&builder_, {18.0, 45.0}, {}, errspec_); } @@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1}); ComputeAndCompareR1<float>(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); } @@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1, 2}); + builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1, 2}); ComputeAndCompareR0<float>(&builder_, 63.0, {}, errspec_); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 5cce7a2bf8..4f7168204f 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -52,7 +52,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XlaBuilder builder(TestName()); auto x = builder.ConstantR1<float>( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto exp = builder.Exp(x); + builder.Exp(x); std::vector<float> expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02, 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, @@ -70,7 +70,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { exponents.push_back(i / 
static_cast<float>(count)); } auto x = builder.ConstantR1<float>(exponents); - auto exp = builder.Exp(x); + builder.Exp(x); std::vector<float> expected; expected.reserve(exponents.size()); @@ -99,7 +99,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { Array4D<float> expected(2, 2, 2, 2, expected_vector); auto x = builder.ConstantR4FromArray4D<float>(exponents); - auto exp = builder.Exp(x); + builder.Exp(x); ComputeAndCompareR4<float>(&builder, expected, {}, ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); @@ -161,7 +161,7 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { XlaBuilder builder(TestName()); auto x = builder.ConstantR1<float>({0.0, -0.0}); - auto exp = builder.SqrtF32(x); + builder.SqrtF32(x); ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_); } @@ -169,7 +169,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { XlaBuilder builder(TestName()); auto x = builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - auto exp = builder.SqrtF32(x); + builder.SqrtF32(x); std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); @@ -179,7 +179,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); auto x = builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); - auto exp = builder.Pow(x, builder.ConstantR0<float>(-.5f)); + builder.Pow(x, builder.ConstantR0<float>(-.5f)); std::vector<float> expected = {.25, 1, .03125, 2.5, 2.23607, .009000, .900025}; @@ -195,7 +195,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1<float>( {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add, {0}); + builder.Map({x, y}, add, {0}); std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -208,7 +208,7 @@ 
XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1<float>( {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Max(x, y); + builder.Max(x, y); std::vector<float> expected = {2.1, -0.6, 2.6, 0.2, 3.8, 2.3, -1.8, 4.9, 1.4, 1.6}; @@ -227,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + builder.Max(v1, v2); ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -267,7 +267,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + builder.Max(v1, v2); ComputeAndCompareR1<float>(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -278,7 +278,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { auto x = builder.ConstantR1<float>( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR0<float>(0); - auto max = builder.Max(x, y); + builder.Max(x, y); std::vector<float> expected = {2.1, 0.0, 2.6, 0.0, 2.1, 2.3, 0.0, 0.0, 0.0, 1.6}; @@ -291,7 +291,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1<float>( {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto min = builder.Min(x, y); + builder.Min(x, y); std::vector<float> expected = {-0.4, -2.6, -3.0, -4.0, 2.1, -2.2, -5.0, -0.9, -2.4, 0.6}; @@ -304,7 +304,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { auto one = builder.ConstantR0<float>(1); auto x = builder.ConstantR1<float>( {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Min(builder.Max(x, 
zero), one); + builder.Min(builder.Max(x, zero), one); std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -317,7 +317,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { auto one = builder.ConstantR0<float>(1); auto x = builder.ConstantR1<float>( {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(zero, x, one); + builder.Clamp(zero, x, one); std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -329,7 +329,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { auto zero = builder.ConstantR1<float>({0.0f, 0.0f}); auto one = builder.ConstantR1<float>({1.0f, 1.0f}); auto x = builder.ConstantR1<float>({2.1, -2.6}); - auto clamp = builder.Clamp(zero, x, one); + builder.Clamp(zero, x, one); std::vector<float> expected = {1.0, 0.0}; ComputeAndCompareR1<float>(&builder, expected, {}); @@ -341,7 +341,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { auto two = builder.ConstantR0<float>(2); auto x = builder.ConstantR1<float>( {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(one, x, two); + builder.Clamp(one, x, two); std::vector<float> expected = {2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0}; @@ -353,7 +353,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { auto zero = builder.ConstantR0<int64>(0); auto one = builder.ConstantR0<int64>(10); auto x = builder.ConstantR1<int64>({-3, 3, 9, 13}); - auto clamp = builder.Clamp(zero, x, one); + builder.Clamp(zero, x, one); std::vector<int64> expected = {0, 3, 9, 10}; ComputeAndCompareR1<int64>(&builder, expected, {}); @@ -380,7 +380,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { auto y_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); auto zero = builder.ConstantR0<float>(0.0); - auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0<float>(5)); + builder.Clamp(zero, y_value, builder.ConstantR0<float>(5)); auto 
computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); clamp = computation_status.ConsumeValueOrDie(); @@ -407,7 +407,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { { auto x = builder.ConstantR1<float>( {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add, {0}); + builder.Map({x}, mult_relu_add, {0}); } std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7, diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index c463f3eac5..3119456347 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -184,8 +184,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector<float>(0); // } -// TODO(b/29185393): does not terminate on CPU. -TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. 
@@ -965,10 +964,8 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, tuple_shape, "t"); - TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), - cond.ConstantR1<float>({42, 42})), - &cond) - .status()); + Any(cond.Eq(cond.GetTupleElement(cond_t, 0), + cond.ConstantR1<float>({42, 42}))); XlaBuilder body("body"); auto body_t = body.Parameter(0, tuple_shape, "t"); @@ -997,12 +994,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); - TF_ASSERT_OK( - Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status()); + Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2}); + body.Parameter(0, element_shape, "t"); + body.Broadcast(body.ConstantR0<float>(1.0), {2}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); @@ -1029,7 +1025,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto body_t = body.Parameter(0, element_shape, "t"); auto tuple = body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))}); - auto e = body.GetTupleElement(tuple, 1); + body.GetTupleElement(tuple, 1); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); @@ -1068,7 +1064,7 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { XlaBuilder body("body"); auto body_t = body.Parameter(0, result_shape, "t"); - auto tuple = body.Tuple( + body.Tuple( {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0<int32>(1)), body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0<int32>(1))}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 
0be950cacb..b081850eb5 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -187,7 +187,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { ClientLibrary::GetOrCreateLocalClient(platform)); XlaBuilder builder(TestName()); - auto result = builder.Tanh(builder.Add( + builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 324b23c24b..44532cb078 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -190,7 +190,6 @@ " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", - "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", " \"\"\"A single RNN layer.\n", "\n", @@ -203,13 +202,12 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", " for ch in chars:\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", @@ -223,7 +221,7 @@ "\n", "\n", " def call(self, inputs, training=False):\n", - " \"\"\"The RNN model code. Uses Eager and \n", + " \"\"\"The RNN model code. 
Uses Eager.\n", "\n", " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", " followed by a fully connected layer with ReLU activation.\n", @@ -243,7 +241,8 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = (length - 1, range(batch_size))\n", + " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", "\n", @@ -381,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 107, "metadata": { "colab": { "autoexec": { @@ -392,9 +391,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 10604, + "elapsed": 5454, "status": "ok", - "timestamp": 1524095272039, + "timestamp": 1529952160455, "user": { "displayName": "", "photoUrl": "", @@ -403,7 +402,7 @@ "user_tz": 240 }, "id": "2pg1AfbxBJQq", - "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", + "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb", "slideshow": { "slide_type": "-" } @@ -413,7 +412,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eval loss at step 100: 0.0674834\n" + "Eval loss at step 100: 0.0705221\n" ] } ], @@ -423,8 +422,8 @@ " 'learning_rate': 0.01,\n", "}\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"\n", "\n", "regressor = tf.estimator.Estimator(\n", @@ -457,7 +456,7 @@ }, { 
"cell_type": "code", - "execution_count": 8, + "execution_count": 108, "metadata": { "colab": { "autoexec": { @@ -468,9 +467,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 7990, + "elapsed": 3432, "status": "ok", - "timestamp": 1524095280105, + "timestamp": 1529952163923, "user": { "displayName": "", "photoUrl": "", @@ -479,7 +478,7 @@ "user_tz": 240 }, "id": "dxHex2tUN_10", - "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", + "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2", "slideshow": { "slide_type": "slide" } @@ -491,12 +490,12 @@ "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -507,12 +506,12 @@ "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -520,15 +519,15 @@ { "data": { "text/html": [ - "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -536,16 +535,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", 
- "//# sourceURL=js_71b9087b6d" + "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n", + "//# sourceURL=js_dc5d7f2784" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -553,16 +552,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_e390445f33" + "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_be7950150b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -570,17 +569,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_241dd76d85" + "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_d0c3bd4eaa" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -588,17 +587,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_60c64e3d50" + 
"window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_f10f6eba86" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -606,17 +605,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_14ea437cbd" + "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff29697179" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -624,17 +623,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_09294c2226" + "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_ff85295dc7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -642,17 +641,17 @@ { "data": { "application/javascript": [ - "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_e5e8266997" + 
"window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ed7aabfedb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -660,17 +659,17 @@ { "data": { "application/javascript": [ - "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_07a097f0ee" + "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_c86f8feaf4" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -678,17 +677,17 @@ { "data": { "application/javascript": [ - "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_790d669ca8" + "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_4d0fde6662" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -696,17 +695,17 @@ { "data": { "application/javascript": [ - "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_d30df771f0" + 
"window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3f66d52720" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -714,32 +713,32 @@ { "data": { "application/javascript": [ - "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_8a43a2da4b" + "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_375f5ae6d7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" }, { "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cO
FCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n", "text/plain": [ - 
"\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" + "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -748,17 +747,17 @@ { "data": { "application/javascript": [ - "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_893ad561f4" + "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_34b0509660" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -766,17 +765,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_2d99e0ac17" + "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_518a0f26fe" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -784,17 +783,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_5c19462e32" + "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_17eb3ff612" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" + 
"\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -802,17 +801,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_b9c8b7567b" + "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_99da807c8e" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -820,17 +819,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_fd05186348" + "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_dee01cb4b6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -838,16 +837,16 @@ { "data": { "text/html": [ - "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 
0x7f3f32414810\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -856,17 +855,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_efef96e882" + "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_8c378be329" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -875,17 +874,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_6eca889864" + "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_f0b946600c" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -894,17 +893,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", - "//# sourceURL=js_f02070cc60" + "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n", + "//# sourceURL=js_9e21b1373a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e" ] }, 
"metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -913,17 +912,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", - "//# sourceURL=js_ed9faba660" + "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_a7764968c6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -932,17 +931,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_f3458d7074" + "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_74279d3ff0" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -951,17 +950,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_3ffd97bd6f" + "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_82b6c34cdb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - 
"id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -970,17 +969,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_7f73e8bcca" + "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff6144734a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -1043,28 +1042,6 @@ "kind": "local" }, "name": "RNN Colorbot using Keras and Estimators", - "provenance": [ - { - "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", - "timestamp": 1523579810961 - }, - { - "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", - "timestamp": 1523016192637 - }, - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], "version": "0.3.2", "views": {} }, diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 2e0a2fcef4..7a30eb94f5 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -36,16 +36,3 @@ add_dependencies( tf_cc_while_loop tf_core_lib tf_protos_cc) - -if(tensorflow_BUILD_PYTHON_BINDINGS) - add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" - ) - add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) -endif() diff --git 
a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 786ea05c74..e3b59001bc 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -456,6 +456,18 @@ add_custom_command( COMMENT "Running SWIG to generate Python wrappers" VERBATIM ) +add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" +) +add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc + tf_python_protos_cc) + set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index a2bfce0362..0fc3773475 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -269,18 +269,20 @@ class FunctionBufferResourceHandleOp : public OpKernel { std::vector<Tensor> func_args; func_args.push_back(*string_arg); + const string& source_device = ctx->device()->name(); + // Obtain and canonicalize target_device. 
const Tensor* target_arg; OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); - const string& target_device = - DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar<string>()()); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar<string>()(), source_device, + &target_device)); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library is provided.")); - const string& source_device = ctx->device()->name(); - mutex_lock l(mu_); if (!initialized_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index b08132cd72..9c7040de9e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -235,6 +235,36 @@ class PrefetchToDeviceTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device( + "/job:localhost/replica:0/task:0/device:CPU:0")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. 
+ with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testPrefetchDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index af41f64286..74c1825a49 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -24,6 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import ops @@ -93,9 +94,18 @@ class RevBlock(tf.keras.Model): for i in reversed(range(len(self.blocks))): block = self.blocks[i] - y_inv = x if i == 0 else block.backward(y, training=training) + if i == 0: + y_inv = x + else: + # Don't update running stats when reconstructing activations + vars_and_vals = block.get_moving_stats() + y_inv = block.backward(y, training=training) + block.restore_moving_stats(vars_and_vals) + + # Update running stats when computing gradients during training dy, 
grads, vars_ = block.backward_grads_and_vars( y_inv, dy, training=training) + grads_all += grads vars_all += vars_ @@ -159,17 +169,18 @@ class _Residual(tf.keras.Model): """Apply residual block to inputs.""" x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) - f_x2 = self.f.call(x2, training=training) + f_x2 = self.f(x2, training=training) # TODO(lxuechen): Replace with simpler downsampling x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) x2_down = ops.downsample( x2, self.filters // 2, self.strides, axis=self.axis) y1 = f_x2 + x1_down - g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error + g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down - if not concat: # Concat option needed for correct backward grads + if not concat: # For correct backward grads return y1, y2 + return tf.concat([y1, y2], axis=self.axis) def backward(self, y, training=True): @@ -178,9 +189,9 @@ class _Residual(tf.keras.Model): assert self.strides == (1, 1) y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) - g_y1 = self.g.call(y1, training=training) + g_y1 = self.g(y1, training=training) x2 = y2 - g_y1 - f_x2 = self.f.call(x2, training=training) + f_x2 = self.f(x2, training=training) x1 = y1 - f_x2 return tf.concat([x1, x2], axis=self.axis) @@ -216,6 +227,22 @@ class _Residual(tf.keras.Model): return tf.concat([dx1, dx2], axis=self.axis), grads, vars_ + def get_moving_stats(self): + vars_and_vals = {} + + def _is_moving_var(v): # pylint: disable=invalid-name + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + + for v in filter(_is_moving_var, self.f.variables + self.g.variables): + vars_and_vals[v] = v.read_value() + + return vars_and_vals + + def restore_moving_stats(self, vars_and_vals): + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) + def _BottleneckResidualInner(filters, strides, diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py 
b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index f4436fd925..a28ca6e3e0 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -240,13 +240,12 @@ class _ResidualTest(tf.test.TestCase): x = tf.random_normal(shape=data_shape) residual = blocks._Residual( filters=16, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_tr, training=True) - # The numerical loss is alarming; reconstructed inputs could differ from - # the original inputs often by more than 1e-3 - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) + x_ = residual.backward(y_tr, training=True) # This updates moving avg + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) def test_backward_channels_last(self): """Test `backward` function with `channels_last` data format.""" @@ -259,12 +258,12 @@ class _ResidualTest(tf.test.TestCase): strides=(1, 1), input_shape=input_shape, data_format="channels_last") + y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_tr, training=True) - # Egregious numerical error - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) + x_ = residual.backward(y_tr, training=True) # This updates moving avg + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) def test_backward_grads_and_vars_channels_first(self): """Test `backward_grads` function with `channels_first` data format.""" @@ -278,6 +277,8 @@ class _ResidualTest(tf.test.TestCase): dy = tf.random_normal(shape=data_shape) residual = blocks._Residual( filters=16, strides=(1, 1), 
input_shape=input_shape) + + vars_and_vals = residual.get_moving_stats() dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( x, dy=dy, training=True) dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( @@ -289,10 +290,23 @@ class _ResidualTest(tf.test.TestCase): self.assertTrue(isinstance(vars_ev, list)) for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, vars_ev): - if grad_tr is not None: # Batch norm moving mean, var gives None grad - self.assertEqual(grad_tr.shape, grad_ev.shape) - self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) + + # Compare against the true gradient computed by the tape + residual.restore_moving_stats(vars_and_vals) + with tf.GradientTape(persistent=True) as tape: + tape.watch(x) + y = residual(x, training=True) + grads = tape.gradient( + y, [x] + residual.trainable_variables, output_gradients=[dy]) + dx_tr_true, grads_tr_true = grads[0], grads[1:] + + del tape + + self.assertAllClose(dx_tr, dx_tr_true, rtol=1e-1, atol=1e-1) + self.assertAllClose(grads_tr, grads_tr_true, rtol=1e-1, atol=1e-1) def test_backward_grads_and_vars_channels_last(self): """Test `backward_grads` function with `channels_last` data format.""" @@ -306,6 +320,7 @@ class _ResidualTest(tf.test.TestCase): strides=(1, 1), input_shape=input_shape, data_format="channels_last") + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( x, dy=dy, training=True) dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( @@ -317,10 +332,9 @@ class _ResidualTest(tf.test.TestCase): self.assertTrue(isinstance(vars_ev, list)) for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, vars_ev): - if grad_tr is not None: # Batch norm moving mean, var gives None grad - self.assertEqual(grad_tr.shape, grad_ev.shape) - 
self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) class _ResidualInnerTest(tf.test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py index 3bc69da5ad..e1d8b3a055 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -26,8 +26,6 @@ import tensorflow as tf IMAGE_HEIGHT = 32 IMAGE_WIDTH = 32 NUM_CHANNEL = 3 -NUM_TRAIN_IMG = 50000 -NUM_TEST_IMG = 10000 def get_ds_from_tfrecords(data_dir, @@ -37,8 +35,8 @@ def get_ds_from_tfrecords(data_dir, epochs=None, shuffle=True, data_format="channels_first", - num_parallel_calls=4, - prefetch=True, + num_parallel_calls=8, + prefetch=0, div255=True, dtype=tf.float32): """Returns a tf.train.Dataset object from reading tfrecords. 
@@ -48,11 +46,12 @@ def get_ds_from_tfrecords(data_dir, split: "train", "validation", or "test" data_aug: Apply data augmentation if True batch_size: Batch size of dataset object - epochs: Number of epochs to repeat the dataset + epochs: Number of epochs to repeat the dataset; default `None` means + repeating indefinitely shuffle: Shuffle the dataset if True data_format: `channels_first` or `channels_last` num_parallel_calls: Number of threads for dataset preprocess - prefetch: Apply prefetch for the dataset if True + prefetch: Buffer size for prefetch div255: Divide the images by 255 if True dtype: Data type of images Returns: @@ -62,7 +61,7 @@ def get_ds_from_tfrecords(data_dir, ValueError: Unknown split """ - if split not in ["train", "validation", "test"]: + if split not in ["train", "validation", "test", "train_all"]: raise ValueError("Unknown split {}".format(split)) def _parser(serialized_example): @@ -74,7 +73,11 @@ def get_ds_from_tfrecords(data_dir, "label": tf.FixedLenFeature([], tf.int64), }) image = tf.decode_raw(features["image"], tf.uint8) - image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + # Initially reshaping to [H, W, C] does not work + image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH]) + # This is needed for `tf.image.resize_image_with_crop_or_pad` + image = tf.transpose(image, [1, 2, 0]) + image = tf.cast(image, dtype) label = tf.cast(features["label"], tf.int32) @@ -93,13 +96,21 @@ def get_ds_from_tfrecords(data_dir, return image, label filename = os.path.join(data_dir, split + ".tfrecords") - dataset = tf.data.TFRecordDataset(filename).repeat(epochs) + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(epochs) dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(prefetch) - if prefetch: - dataset = dataset.prefetch(batch_size) if shuffle: - dataset = dataset.shuffle(NUM_TRAIN_IMG) + # Find the right size according to the split + size = { + 
"train": 40000, + "validation": 10000, + "test": 10000, + "train_all": 50000 + }[split] + dataset = dataset.shuffle(size) + dataset = dataset.batch(batch_size) return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 263a65dc76..30b0edbf43 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -61,12 +61,13 @@ def get_hparams_cifar_38(): config.add_hparam("max_train_iter", 80000) config.add_hparam("seed", 1234) config.add_hparam("shuffle", True) - config.add_hparam("prefetch", True) - config.add_hparam("log_every", 50) - config.add_hparam("save_every", 50) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) config.add_hparam("dtype", tf.float32) - config.add_hparam("eval_batch_size", 500) + config.add_hparam("eval_batch_size", 1000) config.add_hparam("div255", True) + # TODO(lxuechen): This is imprecise, when training with validation set, + # we only have 40k images in training data config.add_hparam("iters_per_epoch", 50000 // config.batch_size) config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) @@ -104,11 +105,10 @@ def get_hparams_imagenet_56(): config.add_hparam("max_train_iter", 600000) config.add_hparam("seed", 1234) config.add_hparam("shuffle", True) - config.add_hparam("prefetch", True) config.add_hparam("log_every", 50) config.add_hparam("save_every", 50) config.add_hparam("dtype", tf.float32) - config.add_hparam("eval_batch_size", 500) + config.add_hparam("eval_batch_size", 1000) config.add_hparam("div255", True) # TODO(lxuechen): Update this according to ImageNet data config.add_hparam("iters_per_epoch", 50000 // config.batch_size) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index 9ef11f8e9b..1065592509 100644 --- 
a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -19,9 +19,11 @@ from __future__ import division from __future__ import print_function import os +import sys from absl import flags import tensorflow as tf +from tqdm import tqdm from tensorflow.contrib.eager.python.examples.revnet import cifar_input from tensorflow.contrib.eager.python.examples.revnet import config as config_ from tensorflow.contrib.eager.python.examples.revnet import revnet @@ -38,28 +40,54 @@ def main(_): tf.enable_eager_execution() config = config_.get_hparams_cifar_38() - model = revnet.RevNet(config=config) - - ds_train = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, - split="train", - data_aug=True, - batch_size=config.batch_size, - epochs=config.epochs, - shuffle=config.shuffle, - data_format=config.data_format, - dtype=config.dtype, - prefetch=config.prefetch) - ds_validation = cifar_input.get_ds_from_tfrecords( + if FLAGS.validate: + # 40k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + # 10k Training set + ds_validation = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="validation", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + else: + # 50k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train_all", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + + # Always compute loss and accuracy on whole training and test set + 
ds_train_one_shot = cifar_input.get_ds_from_tfrecords( data_dir=FLAGS.data_dir, - split="validation", + split="train_all", data_aug=False, batch_size=config.eval_batch_size, epochs=1, + shuffle=False, data_format=config.data_format, dtype=config.dtype, - prefetch=config.prefetch) + prefetch=config.eval_batch_size) ds_test = cifar_input.get_ds_from_tfrecords( data_dir=FLAGS.data_dir, @@ -67,69 +95,116 @@ def main(_): data_aug=False, batch_size=config.eval_batch_size, epochs=1, + shuffle=False, data_format=config.data_format, dtype=config.dtype, - prefetch=config.prefetch) + prefetch=config.eval_batch_size) + model = revnet.RevNet(config=config) global_step = tfe.Variable(1, trainable=False) - - def learning_rate(): # TODO(lxuechen): Remove once cl/201089859 is in place - return tf.train.piecewise_constant(global_step, config.lr_decay_steps, - config.lr_list) - - optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9) - checkpoint = tf.train.Checkpoint( + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + checkpointer = tf.train.Checkpoint( optimizer=optimizer, model=model, optimizer_step=global_step) if FLAGS.train_dir: summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) if FLAGS.restore: latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) - checkpoint.restore(latest_path) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" " + "with global_step: {}".format(latest_path, global_step.numpy())) + sys.stdout.flush() + + warmup(model, config) for x, y in ds_train: loss = train_one_iter(model, x, y, optimizer, global_step=global_step) - if global_step % config.log_every == 0: - it_validation = ds_validation.make_one_shot_iterator() + if global_step.numpy() % config.log_every == 0: + it_train = ds_train_one_shot.make_one_shot_iterator() + acc_train, loss_train = 
evaluate(model, it_train) it_test = ds_test.make_one_shot_iterator() - acc_validation = evaluate(model, it_validation) - acc_test = evaluate(model, it_test) - print("Iter {}, " - "train loss {}, " - "validation accuracy {}, " - "test accuracy {}".format(global_step.numpy(), loss, acc_validation, - acc_test)) + acc_test, loss_test = evaluate(model, it_test) + if FLAGS.validate: + it_validation = ds_validation.make_one_shot_iterator() + acc_validation, loss_validation = evaluate(model, it_validation) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "validation set accuracy {:.4f}, loss {:4.f}" + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_validation, + loss_validation, acc_test, loss_test)) + else: + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_test, + loss_test)) + sys.stdout.flush() if FLAGS.train_dir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("Validation accuracy", acc_validation) - tf.contrib.summary.scalar("Test accuracy", acc_test) tf.contrib.summary.scalar("Training loss", loss) + tf.contrib.summary.scalar("Test accuracy", acc_test) + if FLAGS.validate: + tf.contrib.summary.scalar("Validation accuracy", acc_validation) if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: - checkpoint.save(file_prefix=FLAGS.train_dir + "ckpt") + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" " + "with global_step: {}".format(saved_path, global_step.numpy())) + sys.stdout.flush() + +def warmup(model, config, steps=1): + mock_input = tf.random_normal((config.batch_size,) + config.input_shape) + for _ in range(steps): + model(mock_input, training=False) -def train_one_iter(model, inputs, labels, optimizer, global_step=None): + +def 
train_one_iter(model, + inputs, + labels, + optimizer, + global_step=None, + verbose=False): """Train for one iteration.""" - grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + if FLAGS.manual_grad: + if verbose: + print("Using manual gradients") + grads, vars_, loss = model.compute_gradients(inputs, labels) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + else: # For correctness validation + if verbose: + print("Not using manual gradients") + with tf.GradientTape() as tape: + logits, _ = model(inputs, training=True) + loss = model.compute_loss(logits=logits, labels=labels) + grads = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return loss.numpy() def evaluate(model, iterator): """Compute accuracy with the given dataset iterator.""" + mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: + for x, y in tqdm(iterator): logits, _ = model(x, training=False) + loss = model.compute_loss(logits=logits, labels=y) accuracy( labels=tf.cast(y, tf.int64), predictions=tf.argmax(logits, axis=1, output_type=tf.int64)) + mean_loss(loss) - return accuracy.result().numpy() + return accuracy.result().numpy(), mean_loss.result().numpy() if __name__ == "__main__": @@ -138,10 +213,18 @@ if __name__ == "__main__": default=None, help="[Optional] Directory to store the training information") flags.DEFINE_string( - "data_dir", default=None, help="Directory to load tfrecords.") + "data_dir", default=None, help="Directory to load tfrecords") flags.DEFINE_boolean( "restore", - default=True, + default=False, help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "validate", + default=False, + help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_boolean( + 
"manual_grad", + default=False, + help="[Optional] Use manual gradient graph to save memory") FLAGS = flags.FLAGS tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index b3b8c262b1..0228bff6fa 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -27,6 +27,7 @@ from __future__ import print_function import functools import operator +import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks @@ -47,6 +48,7 @@ class RevNet(tf.keras.Model): self._init_block = self._construct_init_block() self._block_list = self._construct_intermediate_blocks() self._final_block = self._construct_final_block() + self._moving_stats_vars = None def _construct_init_block(self): init_block = tf.keras.Sequential( @@ -153,7 +155,6 @@ class RevNet(tf.keras.Model): def call(self, inputs, training=True): """Forward pass.""" - # Only store hidden states during training if training: saved_hidden = [inputs] @@ -181,17 +182,22 @@ class RevNet(tf.keras.Model): def compute_gradients(self, inputs, labels, training=True): """Manually computes gradients. + This method also SILENTLY updates the running averages of batch + normalization when `training` is set to True. 
+ Args: inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` labels: One-hot labels for classification - training: for batch normalization + training: Use the mini-batch stats in batch norm if set to True Returns: - list of tuple each being (grad, var) for optimizer use + list of tuples each being (grad, var) for optimizer to use """ - # Forward pass record hidden states before downsampling + # Run forward pass to record hidden states; avoid updating running averages + vars_and_vals = self.get_moving_stats() _, saved_hidden = self.call(inputs, training=training) + self.restore_moving_stats(vars_and_vals) grads_all = [] vars_all = [] @@ -201,6 +207,7 @@ class RevNet(tf.keras.Model): with tf.GradientTape() as tape: x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed tape.watch(x) + # Running stats updated below logits = self._final_block(x, training=training) loss = self.compute_loss(logits, labels) @@ -226,16 +233,38 @@ class RevNet(tf.keras.Model): with tf.GradientTape() as tape: x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + # Running stats updated below y = self._init_block(x, training=training) grads_all += tape.gradient( y, self._init_block.trainable_variables, output_gradients=[dy]) vars_all += self._init_block.trainable_variables + # Apply weight decay grads_all = self._apply_weight_decay(grads_all, vars_all) return grads_all, vars_all, loss def _apply_weight_decay(self, grads, vars_): """Update gradients to reflect weight decay.""" - return [g + self.config.weight_decay * v for g, v in zip(grads, vars_)] + # Don't decay bias + return [ + g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g + for g, v in zip(grads, vars_) + ] + + def get_moving_stats(self): + vars_and_vals = {} + + def _is_moving_var(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + + for v in filter(_is_moving_var, self.variables): + vars_and_vals[v] = v.read_value() + + 
return vars_and_vals + + def restore_moving_stats(self, vars_and_vals): + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index cb3bac13f9..a5f240436a 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -36,10 +36,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return loss -class RevnetTest(tf.test.TestCase): +class RevNetTest(tf.test.TestCase): def setUp(self): - super(RevnetTest, self).setUp() + super(RevNetTest, self).setUp() + tf.set_random_seed(1) config = config_.get_hparams_imagenet_56() shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) @@ -56,7 +57,7 @@ class RevnetTest(tf.test.TestCase): del self.x del self.t del self.config - super(RevnetTest, self).tearDown() + super(RevNetTest, self).tearDown() def test_call(self): """Test `call` function.""" @@ -67,7 +68,8 @@ class RevnetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" - grads, vars_, _ = self.model.compute_gradients(inputs=self.x, labels=self.t) + grads, vars_, _ = self.model.compute_gradients( + inputs=self.x, labels=self.t, training=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -84,7 +86,7 @@ class RevnetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _ = compute_gradients(self.x, self.t) + grads, vars_, _ = compute_gradients(self.x, self.t, training=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -144,7 +146,7 @@ class 
MockIterator(object): return self._tensors -class RevnetBenchmark(tf.test.Benchmark): +class RevNetBenchmark(tf.test.Benchmark): """Eager and graph benchmarks for RevNet.""" def _train_batch_sizes(self): diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index f1c60a912c..4bb90cf81b 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,18 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -115,8 +127,9 @@ class DNNEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. 
dropout: When not `None`, the probability we will drop out a given diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index ccaf1128bf..894a295498 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_hidden_units=[1000, 500, 100], dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -116,12 +123,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. 
+ the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index 3bf4abe83d..b960b16f1b 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator): feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator using the FTRL optimizer with regularization. estimator = LinearEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -99,8 +111,9 @@ class LinearEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. 
""" diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index 44ea2dcd90..192162cfce 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -5,7 +5,8 @@ android { buildToolsVersion "26.0.1" defaultConfig { applicationId "android.example.com.tflitecamerademo" - minSdkVersion 15 + // Required by Camera2 API. + minSdkVersion 21 targetSdkVersion 26 versionCode 1 versionName "1.0" diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 69a2f638af..a4229f91f5 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -50,6 +50,7 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework.importer import import_graph_def @@ -269,6 +270,48 @@ class TocoConverter(object): return cls( graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + @classmethod + def from_keras_model_file(cls, + model_file, + input_arrays=None, + input_shapes=None, + output_arrays=None): + """Creates a TocoConverter class from a tf.keras model file. + + Args: + model_file: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. 
(default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + + Returns: + TocoConverter class. + """ + _keras.backend.clear_session() + _keras.backend.set_learning_phase(False) + keras_model = _keras.models.load_model(model_file) + sess = _keras.backend.get_session() + + # Get input and output tensors. + if input_arrays: + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + else: + input_tensors = keras_model.inputs + + if output_arrays: + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + else: + output_tensors = keras_model.outputs + set_tensor_shapes(input_tensors, input_shapes) + + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + def convert(self): """Converts a TensorFlow GraphDef based on instance variables. @@ -366,7 +409,7 @@ def _is_frozen_graph(sess): Bool. 
""" for op in sess.graph.get_operations(): - if op.type.startswith("Variable"): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): return False return True diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index a9475de474..ca2af5aaed 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import tempfile import numpy as np from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): self.assertTrue(tflite_model) +class FromKerasFile(test_util.TensorFlowTestCase): + + def setUp(self): + keras.backend.clear_session() + + def _getSequentialModel(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + try: + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + return keras_file + + def testSequentialModel(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file = self._getSequentialModel() + + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = 
converter.convert() + self.assertTrue(tflite_model) + + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSequentialModelInputArray(self): + """Test a Sequential tf.keras model testing input arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid input array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['invalid-input']) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['dense_input']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testSequentialModelInputShape(self): + """Test a Sequential tf.keras model testing input shapes argument.""" + keras_file = self._getSequentialModel() + + # Passing in shape of invalid input array has no impact as long as all input + # arrays have a shape. 
+ converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'invalid-input': [2, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Passing in shape of valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'dense_input': [2, 3]}) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + # Check input shape from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertTrue(([2, 3] == input_details[0]['shape']).all()) + + def testSequentialModelOutputArray(self): + """Test a Sequential tf.keras model testing output arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid output array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['invalid-output']) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) + + # Valid output array. 
+ converter = lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['time_distributed/Reshape_1']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testFunctionalModel(self): + """Test a Functional tf.keras model with default inputs.""" + inputs = keras.layers.Input(shape=(3,), name='input') + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFunctionalModelMultipleInputs(self): + """Test a Functional tf.keras model with multiple inputs and outputs.""" + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.mae], + loss_weights=[1., 0.5]) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + model.predict([input_a_np, input_b_np], batch_size=5) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input_a', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('input_b', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('dropout/Identity', output_details[1]['name']) + self.assertEqual(np.float32, output_details[1]['dtype']) + self.assertTrue(([1, 4] == output_details[1]['shape']).all()) + self.assertEqual((0., 0.), output_details[1]['quantization']) + + def testFunctionalSequentialModel(self): + """Test a Functional tf.keras model containing a Sequential model.""" + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model = keras.models.Model(model.input, model.output) + + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + 
keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index d18a29834b..249b940f92 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -74,6 +74,9 @@ def _get_toco_converter(flags): converter_kwargs["saved_model_dir"] = flags.saved_model_dir converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) converter_kwargs["signature_key"] = flags.saved_model_signature_key + elif flags.keras_model_file: + converter_fn = lite.TocoConverter.from_keras_model_file + converter_kwargs["model_file"] = flags.keras_model_file return converter_fn(**converter_kwargs) @@ -227,6 +230,10 @@ def run_main(_): "--saved_model_dir", type=str, help="Full filepath of directory containing the SavedModel.") + input_file_group.add_argument( + 
"--keras_model_file", + type=str, + help="Full filepath of HDF5 file containing tf.Keras model.") # Model format flags. parser.add_argument( diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index afa6fd6957..b04d166f89 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -15,6 +15,7 @@ Table of contents: * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) * [Exporting a GraphDef from file](#basic-graphdef-file) * [Exporting a SavedModel](#basic-savedmodel) + * [Exporting a tf.keras File](#basic-keras-file) * [Complex examples](#complex) * [Exporting a quantized GraphDef](#complex-quant) * [TensorFlow Lite Python interpreter](#interpreter) @@ -114,6 +115,51 @@ For more complex SavedModels, the optional parameters that can be passed into `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are available by running `help(tf.contrib.lite.TocoConverter)`. +### Exporting a tf.keras File <a name="basic-keras-file"></a> + +The following example shows how to convert a tf.keras model into a TensorFlow +Lite FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +The tf.keras file must contain both the model and the weights. A comprehensive +example including model construction can be seen below. + +```python +import numpy as np +import tensorflow as tf + +# Generate tf.keras model. 
+model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Dense(2, input_shape=(3,))) +model.add(tf.keras.layers.RepeatVector(3)) +model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))) +model.compile(loss=tf.keras.losses.MSE, + optimizer=tf.keras.optimizers.RMSprop(lr=0.0001), + metrics=[tf.keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + +x = np.random.random((1, 3)) +y = np.random.random((1, 3, 3)) +model.train_on_batch(x, y) +model.predict(x) + +# Save tf.keras model in HDF5 format. +keras_file = "keras_model.h5" +tf.keras.models.save_model(model, keras_file) + +# Convert to TensorFlow Lite model. +converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + ## Complex examples <a name="complex"></a> For models where the default value of the attributes is not sufficient, the diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index da7e5add7e..485e853e25 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -378,7 +378,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, for (int i = 0; i < input_flat_size; i++) { output_bool_data[i] = input_tensor.bool_val(0); } - } else if (input_tensor.int_val_size() == input_flat_size) { + } else if (input_tensor.bool_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); } diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index f1ef218e74..3e41e3d0b4 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -81,6 +81,19 @@ class EagerFileTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. 
self.assertTrue(gfile.Exists(logdir)) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerMemory(self): + training_util.get_or_create_global_step() + logdir = self.get_temp_dir() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 59e76cb575..0e41170367 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -793,6 +793,7 @@ tf_cuda_library( "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", + "framework/kernel_def_util.h", "framework/log_memory.h", "framework/lookup_interface.h", "framework/memory_types.h", @@ -1198,6 +1199,7 @@ tf_cuda_library( hdrs = [ "common_runtime/device.h", "common_runtime/device_factory.h", + "common_runtime/function.h", "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "graph/algorithm.h", @@ -3377,6 +3379,7 @@ tf_cc_tests( "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", "framework/kernel_def_builder_test.cc", + "framework/kernel_def_util_test.cc", "framework/memory_types_test.cc", "framework/node_def_builder_test.cc", "framework/node_def_util_test.cc", diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc new file mode 100644 index 0000000000..bbd3dd3e57 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_util.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/kernel_def_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/kernel_def.pb_text.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +namespace { +// Helper for KernelAttrsMatch(). 
+bool InTypeList(DataType dt, const AttrValue& type_list) { + for (int in_list : type_list.list().type()) { + if (dt == in_list) return true; + } + return false; +} +} // namespace + +Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, + bool* match) { + *match = false; + for (const auto& constraint : kernel_def.constraint()) { + if (constraint.allowed_values().list().type_size() == 0) { + return errors::Unimplemented( + "KernelDef '", ProtoShortDebugString(kernel_def), + " has constraint on attr '", constraint.name(), + "' with unsupported type: ", + SummarizeAttrValue(constraint.allowed_values())); + } + + const AttrValue* found = attrs.Find(constraint.name()); + if (found) { + if (found->type() != DT_INVALID) { + if (!InTypeList(found->type(), constraint.allowed_values())) { + return Status::OK(); + } + } else { + if (!AttrValueHasType(*found, "list(type)").ok()) { + return errors::InvalidArgument( + "KernelDef '", ProtoShortDebugString(kernel_def), + "' has constraint on attr '", constraint.name(), + "' that has value '", SummarizeAttrValue(*found), + "' that does not have type 'type' or 'list(type)' in NodeDef " + "'", + attrs.SummarizeNode(), "'"); + } + + for (int t : found->list().type()) { + if (!InTypeList(static_cast<DataType>(t), + constraint.allowed_values())) { + return Status::OK(); + } + } + } + } else { + return errors::InvalidArgument( + "OpKernel '", kernel_def.op(), "' has constraint on attr '", + constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), + "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); + } + } + *match = true; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_def_util.h b/tensorflow/core/framework/kernel_def_util.h new file mode 100644 index 0000000000..b973cefc4f --- /dev/null +++ b/tensorflow/core/framework/kernel_def_util.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" + +namespace tensorflow { + +// Returns whether the attrs satisfy the constraints in the kernel_def. Returns +// an error if attrs in kernel_def are not found, or have a mismatching type. +Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, + bool* match); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/kernel_def_util_test.cc b/tensorflow/core/framework/kernel_def_util_test.cc new file mode 100644 index 0000000000..a2e4aa82fa --- /dev/null +++ b/tensorflow/core/framework/kernel_def_util_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/kernel_def_util.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +NodeDef NodeDefFromText(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +KernelDef KernelDefFromText(const string& text) { + KernelDef kernel_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def)); + return kernel_def; +} + +class AttrsMatchTest : public ::testing::Test { + protected: + void ExpectStatus(const string& node_def_str, const string& kernel_def_str, + error::Code code) { + bool match; + auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str), + NodeDefFromText(node_def_str), &match); + LOG(INFO) << "status: " << status; + EXPECT_EQ(code, status.code()); + if (!status.ok()) { + EXPECT_FALSE(match) + << "Expect no match between the given NodeDef and KernelDef"; + } + } +}; + +TEST_F(AttrsMatchTest, ValidConstraint) { + string node_def_str = R"( + name: "ValidConstraint-op" + op: "ValidConstraint" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + )"; + string kernel_def_str = R"( + op: "ValidConstraint" + device_type: "CPU" + constraint { + name: "T" + allowed_values { + list { + type: DT_FLOAT + } + } + } + )"; + ExpectStatus(node_def_str, kernel_def_str, error::OK); +} + +TEST_F(AttrsMatchTest, BadConstraint) { + string node_def_str = R"( + name: "BadConstraint-op" + op: "BadConstraint" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + )"; + string kernel_def_str = R"( + op: "BadConstraint" + device_type: "CPU" + constraint { + name: "T" + 
allowed_values { + list { + type: DT_FLOAT + } + } + } + )"; + ExpectStatus(node_def_str, kernel_def_str, error::INVALID_ARGUMENT); +} + +TEST_F(AttrsMatchTest, Unimplemented) { + string node_def_str = R"( + name: "BadConstraint-op" + op: "BadConstraint" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + )"; + string kernel_def_str = R"( + op: "BadConstraint" + device_type: "CPU" + constraint { + name: "T" + allowed_values { + list { + } + } + } + )"; + ExpectStatus(node_def_str, kernel_def_str, error::UNIMPLEMENTED); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c2561b5019..8a332fa1d8 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/kernel_def.pb_text.h" +#include "tensorflow/core/framework/kernel_def_util.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -969,62 +970,6 @@ void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, namespace { -// Helper for AttrsMatch(). -bool InTypeList(DataType dt, const AttrValue& type_list) { - for (int in_list : type_list.list().type()) { - if (dt == in_list) return true; - } - return false; -} - -// Returns whether the attrs satisfy the constraints in the kernel_def. Returns -// an error if attrs in kernel_def are not found, or have a mismatching type. 
-Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) { - *match = false; - for (const auto& constraint : kernel_def.constraint()) { - if (constraint.allowed_values().list().type_size() == 0) { - return errors::Unimplemented( - "KernelDef '", ProtoShortDebugString(kernel_def), - " has constraint on attr '", constraint.name(), - "' with unsupported type: ", - SummarizeAttrValue(constraint.allowed_values())); - } - - const AttrValue* found = attrs.Find(constraint.name()); - if (found) { - if (found->type() != DT_INVALID) { - if (!InTypeList(found->type(), constraint.allowed_values())) { - return Status::OK(); - } - } else { - if (!AttrValueHasType(*found, "list(type)").ok()) { - return errors::InvalidArgument( - "KernelDef '", ProtoShortDebugString(kernel_def), - "' has constraint on attr '", constraint.name(), - "' that has value '", SummarizeAttrValue(*found), - "' that does not have type 'type' or 'list(type)' in NodeDef " - "'", - attrs.SummarizeNode(), "'"); - } - - for (int t : found->list().type()) { - if (!InTypeList(static_cast<DataType>(t), - constraint.allowed_values())) { - return Status::OK(); - } - } - } - } else { - return errors::InvalidArgument( - "OpKernel '", kernel_def.op(), "' has constraint on attr '", - constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), - "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); - } - } - *match = true; - return Status::OK(); -} - static const StringPiece kKernelAttr("_kernel"); // TODO(irving): Replace with const Node& version below. @@ -1043,7 +988,7 @@ Status FindKernelRegistration(const DeviceType& device_type, // If there is a kernel registered for the op and device_type, // check that the attrs match. 
bool match; - TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); + TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match)); if (match) { if (*reg != nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc index 80c76df255..b5c2c2aac8 100644 --- a/tensorflow/core/graph/tensor_id.cc +++ b/tensorflow/core/graph/tensor_id.cc @@ -24,6 +24,9 @@ namespace tensorflow { TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {} +SafeTensorId::SafeTensorId(StringPiece str, int idx) + : SafeTensorId(str.ToString(), idx) {} + SafeTensorId::SafeTensorId(const TensorId& id) : SafeTensorId(id.first.ToString(), id.second) {} diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h index bf13fc78a6..b0978b4120 100644 --- a/tensorflow/core/graph/tensor_id.h +++ b/tensorflow/core/graph/tensor_id.h @@ -68,6 +68,7 @@ struct SafeTensorId : public std::pair<string, int> { // NOTE(skyewm): this is required on some platforms. I'm not sure why the // using statement above isn't always sufficient. SafeTensorId() : Base() {} + SafeTensorId(StringPiece str, int idx); SafeTensorId(const TensorId& id); string ToString() const { diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 90be051764..d8c5d09c4d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2519,33 +2519,32 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { bool* modified) { const auto& t = ctx().graph_properties->GetInputProperties(input->name())[i]; - for (int k = 0; k < t.shape().dim_size(); ++k) { - // Skip if t shape is not fully determined. 
- if (t.shape().dim(k).size() < 0) { + const auto& c = + ctx().graph_properties->GetInputProperties(input->name())[j]; + for (int k = 0; k < c.shape().dim_size(); ++k) { + // Skip if c shape is not fully determined. + if (c.shape().dim(k).size() < 0) { return Status::OK(); } } - const auto& c = - ctx().graph_properties->GetInputProperties(input->name())[j]; TensorShapeProto broadcast_shape; if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { - return errors::InvalidArgument("Cannot get broadcast shape for: ", - t.DebugString(), " and ", c.DebugString()); + return Status::OK(); } if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { // skip if the non-constant tensor doesn't have the same shape after // broadcast. return Status::OK(); } - if (TensorShape::IsValid(t.shape()) && t.has_value()) { - Tensor tensor(t.dtype(), t.shape()); - if (!tensor.FromProto(t.value())) { + if (TensorShape::IsValid(c.shape()) && c.has_value()) { + Tensor constant(c.dtype(), c.shape()); + if (!constant.FromProto(c.value())) { return errors::InvalidArgument("Cannot parse tensor from proto: ", - t.value().DebugString()); + c.value().DebugString()); } complex128 element; - for (int k = 0; k < tensor.NumElements(); ++k) { - if (!GetElement(tensor, k, &element)) { + for (int k = 0; k < constant.NumElements(); ++k) { + if (!GetElement(constant, k, &element)) { // input data type is not supported by log1p. Skip. 
return Status::OK(); } @@ -2558,11 +2557,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x)); TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y)); node->set_op("Log1p"); - node->set_input(0, y->name()); - node->add_input(AsControlDependency(x->name())); + node->set_input(0, input->input(i)); + node->add_input(AsControlDependency(y->name())); ForwardControlDependencies(node, {input}); AddToOptimizationQueue(node); + AddToOptimizationQueue(input); AddToOptimizationQueue(x); AddToOptimizationQueue(y); *modified = true; diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 0de97de205..f942b1a8a9 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -98,6 +98,8 @@ typedef Eigen::SyclDevice SYCLDevice; TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +// quint16 not included in QUANTIZED_TYPES +TF_CALL_quint16(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index f2724735bf..fcdf6c447c 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -302,15 +302,21 @@ class RemoteCallOp : public AsyncOpKernel { ~RemoteCallOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - const Tensor* target; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); - const string& target_device = - DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()()); - FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library is provided."), done); + + const string& source_device = lib->device()->name(); + const Tensor* target; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); + string target_device; + 
OP_REQUIRES_OK_ASYNC( + ctx, + DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(), + source_device, &target_device), + done); + AttrValueMap attr_values = func_.attr(); FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = target_device; @@ -345,7 +351,7 @@ class RemoteCallOp : public AsyncOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); opts.runner = ctx->runner(); - opts.source_device = lib->device()->name(); + opts.source_device = source_device; if (opts.source_device != target_device) { opts.remote_execution = true; } diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 23fdfe944a..f08dd4f750 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -133,7 +133,6 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, bool should_select = true; for (int j = selected.size() - 1; j >= 0; --j) { iou = IOU(boxes_data, next_candidate.box_index, selected[j]); - if (iou == 0.0) continue; if (iou > iou_threshold) should_select = false; } diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc index 41494f56c5..3b9133ed7e 100644 --- a/tensorflow/core/kernels/pad_op.cc +++ b/tensorflow/core/kernels/pad_op.cc @@ -320,7 +320,7 @@ namespace functor { DECLARE_GPU_SPEC(T, 5); \ DECLARE_GPU_SPEC(T, 6); -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPECS); TF_CALL_int8(DECLARE_GPU_SPECS); } // namespace functor @@ -353,7 +353,7 @@ TF_CALL_int8(DECLARE_GPU_SPECS); .HostMemory("constant_values"), \ PadOp<GPUDevice, T, int64>) -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); +TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL); TF_CALL_int8(REGISTER_GPU_KERNEL); // A special GPU kernel for int32. 
diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc index 8e13e19e2e..00ec44adc2 100644 --- a/tensorflow/core/kernels/pad_op_gpu.cu.cc +++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc @@ -39,7 +39,7 @@ typedef Eigen::GpuDevice GPUDevice; DEFINE_GPU_PAD_SPECS(T, int32) \ DEFINE_GPU_PAD_SPECS(T, int64) -TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); +TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPECS); TF_CALL_int8(DEFINE_GPU_SPECS); } // namespace tensorflow diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 2c0576ff10..1c130ba300 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -354,6 +354,18 @@ struct bfloat16 { return x; } + static bfloat16 highest() { + bfloat16 x; + x.value = 0x7F7F; // 0x1.FEp127 + return x; + } + + static bfloat16 lowest() { + bfloat16 x; + x.value = 0xFF7F; // -0x1.FEp127 + return x; + } + uint16_t value; // A value that represents "not a number". 
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 90c3fed2e8..8c24076aa9 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -184,16 +184,65 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) { return true; } +namespace { + +void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename, + DeviceNameUtils::ParsedName* parsed_name) { + if (!parsed_name->has_job) { + parsed_name->job = parsed_basename.job; + parsed_name->has_job = true; + } + if (!parsed_name->has_replica) { + parsed_name->replica = parsed_basename.replica; + parsed_name->has_replica = true; + } + if (!parsed_name->has_task) { + parsed_name->task = parsed_basename.task; + parsed_name->has_task = true; + } + if (!parsed_name->has_type) { + parsed_name->type = parsed_basename.type; + parsed_name->has_type = true; + } + if (!parsed_name->has_id) { + parsed_name->id = parsed_basename.id; + parsed_name->has_id = true; + } +} + +} // namespace + /* static */ -string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) { +Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname, + StringPiece basename, + string* canonical_name) { + *canonical_name = ""; + ParsedName parsed_basename; + if (!ParseFullName(basename, &parsed_basename)) { + return errors::InvalidArgument("Could not parse basename: ", basename, + " into a device specification."); + } + if (!(parsed_basename.has_job && parsed_basename.has_replica && + parsed_basename.has_task && parsed_basename.has_type && + parsed_basename.has_id)) { + return errors::InvalidArgument("Basename: ", basename, + " should be fully " + "specified."); + } ParsedName parsed_name; if (ParseLocalName(fullname, &parsed_name)) { - return ParsedNameToString(parsed_name); + CompleteName(parsed_basename, &parsed_name); + *canonical_name = ParsedNameToString(parsed_name); + return Status::OK(); } if 
(ParseFullName(fullname, &parsed_name)) { - return ParsedNameToString(parsed_name); + CompleteName(parsed_basename, &parsed_name); + *canonical_name = ParsedNameToString(parsed_name); + return Status::OK(); } - return ""; + return errors::InvalidArgument("Could not parse ", fullname, + " into a device " + "specification."); } /* static */ diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h index 0ae28df997..4071a70836 100644 --- a/tensorflow/core/util/device_name_utils.h +++ b/tensorflow/core/util/device_name_utils.h @@ -88,10 +88,14 @@ class DeviceNameUtils { // Parses "fullname" into "*parsed". Returns true iff succeeds. static bool ParseFullName(StringPiece fullname, ParsedName* parsed); - // Canonicalizes "fullname". Accepts both legacy, newer and local versions of - // the device spec. Returns the newer version of the device spec. If we were - // unable to interpret / parse "fullname" returns "". - static string CanonicalizeDeviceName(StringPiece fullname); + // Canonicalizes "fullname" into "*canonical_name". Uses a fully specified + // basename to fill in fields that are missing. Accepts both legacy, newer + // and local versions of the device spec. Returns the newer version of the + // device spec. If we were unable to interpret / parse "fullname" returns + // an error and *canonical_name is set to "". + static Status CanonicalizeDeviceName(StringPiece fullname, + StringPiece basename, + string* canonical_name); // Returns true if "name" specifies any non-trivial constraint on the device. 
static bool HasSomeDetails(const ParsedName& name) { diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index ff9c108f10..dafb3b20b9 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -467,18 +467,41 @@ TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) { } TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) { - EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", - DeviceNameUtils::CanonicalizeDeviceName( - "/job:foo/replica:10/task:0/device:CPU:1")); - EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", - DeviceNameUtils::CanonicalizeDeviceName( - "/job:foo/task:0/replica:10/device:CPU:1")); - EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", - DeviceNameUtils::CanonicalizeDeviceName( - "/job:foo/task:0/replica:10/cpu:1")); - EXPECT_EQ("/device:CPU:0", DeviceNameUtils::CanonicalizeDeviceName("CPU:0")); - EXPECT_EQ("", DeviceNameUtils::CanonicalizeDeviceName( - "/job:foo/task:0/replica/cpu:1")); + string canonical_name; + { + // Good basename. 
+ string basename = "/job:foo/replica:10/task:0/device:CPU:0"; + TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName( + "/job:foo/replica:10/task:0/device:CPU:1", basename, &canonical_name)); + EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name); + TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName( + "/job:foo/task:0/replica:10/device:CPU:1", basename, &canonical_name)); + EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name); + TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName( + "/job:foo/task:0/replica:10/cpu:1", basename, &canonical_name)); + EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name); + TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName("CPU:0", basename, + &canonical_name)); + EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:0", canonical_name); + Status s = DeviceNameUtils::CanonicalizeDeviceName( + "/job:foo/task:0/replica/cpu:1", basename, &canonical_name); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_EQ("", canonical_name); + } + + { + // Try out malformed basenames. 
+ string fullname = "/device:CPU:0"; + + Status s = DeviceNameUtils::CanonicalizeDeviceName( + fullname, "/device:CPU:0", &canonical_name); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_EQ("", canonical_name); + s = DeviceNameUtils::CanonicalizeDeviceName( + fullname, "/job:foo/task:0/replica/cpu:1", &canonical_name); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_EQ("", canonical_name); + } } static void BM_ParseFullName(int iters) { diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h index ee43945a39..90672a10a8 100644 --- a/tensorflow/core/util/saved_tensor_slice_util.h +++ b/tensorflow/core/util/saved_tensor_slice_util.h @@ -123,6 +123,7 @@ TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index f7e116bf0f..ce43d09b63 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1308,12 +1308,10 @@ See also : : : parameters of type T and M of : : : : arbitrary type : | `dimensions` | `int64` array | array of map dimensions | -| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type | Applies a scalar function over the given `operands` arrays, producing an array of the same dimensions where each element is the result of the mapped function -applied to the corresponding elements in the input arrays with `static_operands` -given as additional input to `computation`. +applied to the corresponding elements in the input arrays. 
The mapped function is an arbitrary computation with the restriction that it has N inputs of scalar type `T` and a single output with type `S`. The output has diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5d9a5130a0..f19bdeaa39 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3925,7 +3925,7 @@ tf_cuda_library( tf_py_test( name = "session_test", - size = "small", + size = "medium", srcs = ["client/session_test.py"], additional_deps = [ ":array_ops", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 35aa37ac6d..f3b788f931 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface): raise type(e)(node_def, op, message) def _extend_graph(self): - with self._graph._lock: # pylint: disable=protected-access + with self._graph._session_run_lock(): # pylint: disable=protected-access tf_session.ExtendSession(self._session) # The threshold to run garbage collection to delete dead tensors. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e49d067105..b72e029d1c 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import random import os import sys import threading @@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase): for t in threads: t.join() - def testParallelRunAndBuild(self): + @staticmethod + def _build_graph(): + time.sleep(random.random() * 0.1) + # Do some graph construction. Try to exercise non-trivial paths. 
+ graph = ops.get_default_graph() + gdef = None + for _ in range(10): + x = array_ops.placeholder(dtype=dtypes.float32) + with ops.colocate_with(x): + y = array_ops.placeholder(dtype=dtypes.float32) + with ops.device('/cpu:0'): + z = control_flow_ops.while_loop( + lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) + with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): + gradients_impl.gradients(z, [x, y]) + if gdef is None: + gdef = graph.as_graph_def() + else: + importer.import_graph_def(gdef, name='import') + + def testParallelRunAndSingleBuild(self): with session.Session() as sess: c = constant_op.constant(5.0) stop = threading.Event() def run_loop(): while not stop.is_set(): + time.sleep(random.random() * 0.1) self.assertEqual(sess.run(c), 5.0) - threads = [self.checkedThread(target=run_loop) for _ in range(100)] + threads = [self.checkedThread(target=run_loop) for _ in range(10)] for t in threads: t.start() - # Do some graph construction. Try to exercise non-trivial paths. 
- graph = ops.get_default_graph() - gdef = None - for _ in range(10): - x = array_ops.placeholder(dtype=dtypes.float32) - with ops.colocate_with(x): - y = array_ops.placeholder(dtype=dtypes.float32) - with ops.device('/cpu:0'): - z = control_flow_ops.while_loop( - lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) - with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): - gradients_impl.gradients(z, [x, y]) - if gdef is None: - gdef = graph.as_graph_def() - else: - importer.import_graph_def(gdef, name='import') + SessionTest._build_graph() stop.set() for t in threads: t.join() + def testParallelRunAndParallelBuild(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + stop = threading.Event() + + def run_loop(): + while not stop.is_set(): + time.sleep(random.random() * 0.1) + self.assertEqual(sess.run(c), 5.0) + + run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] + for t in run_threads: + t.start() + + build_threads = [self.checkedThread(target=SessionTest._build_graph) + for _ in range(10)] + for t in build_threads: + t.start() + for t in build_threads: + t.join() + + # Let the run_threads run until the build threads are finished. 
+ stop.set() + for t in run_threads: + t.join() + def testRunFeedDict(self): with session.Session() as s: x = array_ops.zeros([2]) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 50bb0837b7..c3d42b49af 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from absl.testing import parameterized import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -461,5 +464,55 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): 5, padded_shapes=shape_as_tensor) +class BatchDatasetBenchmark(test.Benchmark): + + def benchmarkBatchSparse(self): + non_zeros_per_row_values = [0, 1, 5, 10, 100] + batch_size_values = [1, 32, 64, 128, 1024] + + sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64) + batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[]) + + dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat( + ).batch(batch_size_placeholder) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + for non_zeros_per_row in non_zeros_per_row_values: + + sparse_value = sparse_tensor.SparseTensorValue( + indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis], + values=np.arange(non_zeros_per_row, dtype=np.int64), + dense_shape=[1000]) + + for batch_size in batch_size_values: + + with session.Session() as sess: + sess.run(iterator.initializer, feed_dict={ + sparse_placeholder: sparse_value, + batch_size_placeholder: batch_size}) + # Run five steps to warm up the session caches before taking 
the + # first measurement. + for _ in range(5): + sess.run(next_element.indices.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.indices.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100.0 + + print('Batch sparse dataset non-zeros per row: %d batch_size: %d ' + 'wall time: %f' + % (non_zeros_per_row, batch_size, median_wall_time)) + self.report_benchmark( + iters=10000, wall_time=median_wall_time, + name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % ( + non_zeros_per_row, batch_size)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index fc68e945c0..a81ef90513 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -47,8 +47,11 @@ def capture_value(tensor_map, value, dtype, name): """Capture a value from outside the function, to pass in as an extra arg.""" captured_value = tensor_map.get(ops.tensor_id(value), None) if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) + # Note: setting ops.control_dependencies(None) ensures we always put + # capturing placeholders outside of any control flow context. 
+ with ops.control_dependencies(None): + captured_value = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) if captured_value.dtype == dtypes_module.resource: if ops._USE_C_SHAPES: # pylint: disable=protected-access if isinstance(value, ops.EagerTensor): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a5df3ef530..9e5754fc4c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -210,6 +210,21 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + def testVariableInLoopInFunction(self): + + @function.defun + def test_function(): + + def loop_test(_): + return False + + def loop_body(_): + return variable_scope.get_variable('a', shape=()) + + return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) + + self.assertEqual(test_function().shape, []) + def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): with context.graph_mode(): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 90889e3e5d..2c7c4285ca 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -230,6 +230,17 @@ class DNNClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. 
estimator = DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -317,8 +328,9 @@ class DNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -385,6 +397,17 @@ class DNNRegressor(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNRegressor( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNRegressor( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -465,8 +488,9 @@ class DNNRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. 
activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 3d1ad1365b..2f20e4b289 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -257,12 +257,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -325,12 +332,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. 
Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -441,12 +452,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -508,12 +526,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. 
Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index ac59e786c4..e22df849e5 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -193,6 +193,17 @@ class LinearClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearClassifier( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearClassifier( feature_columns=[categorical_column_a, @@ -272,8 +283,9 @@ class LinearClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -335,10 +347,31 @@ class LinearRegressor(estimator.Estimator): categorical_feature_a_x_categorical_feature_b = crossed_column(...) + # Estimator using the default optimizer. 
estimator = LinearRegressor( feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearRegressor( feature_columns=[categorical_column_a, @@ -409,8 +442,9 @@ class LinearRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. 
warm_start_from: A string filepath to a checkpoint to warm-start from, or diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py index f72c5ca5cb..8f51cc3a80 100644 --- a/tensorflow/python/estimator/canned/optimizers.py +++ b/tensorflow/python/estimator/canned/optimizers.py @@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None): raise ValueError( 'Unsupported optimizer name: {}. Supported names are: {}'.format( opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES))))) + if callable(opt): + opt = opt() if not isinstance(opt, optimizer_lib.Optimizer): raise ValueError( 'The given object is not an Optimizer instance. Given: {}'.format(opt)) diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py index ee28756155..eadabdbc49 100644 --- a/tensorflow/python/estimator/canned/optimizers_test.py +++ b/tensorflow/python/estimator/canned/optimizers_test.py @@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import rmsprop +class _TestOptimizer(optimizer_lib.Optimizer): + + def __init__(self): + super(_TestOptimizer, self).__init__( + use_locking=False, name='TestOptimizer') + + class GetOptimizerInstance(test.TestCase): def test_unsupported_name(self): @@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase): self.assertAlmostEqual(0.1, opt._learning_rate) def test_object(self): - class _TestOptimizer(optimizer_lib.Optimizer): - - def __init__(self): - super(_TestOptimizer, self).__init__( - use_locking=False, name='TestOptimizer') - opt = optimizers.get_optimizer_instance(_TestOptimizer()) self.assertIsInstance(opt, _TestOptimizer) @@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase): ValueError, 'The given object is not an Optimizer instance'): optimizers.get_optimizer_instance((1, 2, 3)) + def test_callable(self): + def _optimizer_fn(): + return 
_TestOptimizer() + opt = optimizers.get_optimizer_instance(_optimizer_fn) + self.assertIsInstance(opt, _TestOptimizer) + + def test_lambda(self): + opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda + self.assertIsInstance(opt, _TestOptimizer) + + def test_callable_returns_invalid(self): + def _optimizer_fn(): + return (1, 2, 3) + with self.assertRaisesRegexp( + ValueError, 'The given object is not an Optimizer instance'): + optimizers.get_optimizer_instance(_optimizer_fn) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 72eb7e0eeb..699d2b70d1 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -407,11 +407,11 @@ def import_graph_def(graph_def, _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements) - # _ProcessNewOps mutates the new operations. _lock ensures a Session.run - # call cannot occur between creating the TF_Operations in the + # _ProcessNewOps mutates the new operations. _mutation_lock ensures a + # Session.run call cannot occur between creating the TF_Operations in the # TF_GraphImportGraphDefWithResults call and mutating the them in # _ProcessNewOps. 
- with graph._lock: # pylint: disable=protected-access + with graph._mutation_lock(): # pylint: disable=protected-access with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: try: results = c_api.TF_GraphImportGraphDefWithResults( diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 05f9ae21b1..cf0b1e36fb 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -2599,6 +2600,10 @@ def _name_from_scope_name(name): return name[:-1] if (name and name[-1] == "/") else name +_MUTATION_LOCK_GROUP = 0 +_SESSION_RUN_LOCK_GROUP = 1 + + @tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2648,20 +2653,21 @@ class Graph(object): def __init__(self): """Creates a new, empty Graph.""" - # Protects core state that can be returned via public accessors, as well as - # synchronizes Session.run calls with methods that create and mutate ops - # (e.g. Graph.create_op()). This synchronization is necessary because it's - # illegal to modify an operation after it's been run. Thread-safety is - # provided on a best-effort basis to support buggy programs, and is not - # guaranteed by the public `tf.Graph` API. - # - # The lock must be reentrant because create_op can be called recursively due - # to control flow. Without a reentrant lock, many methods would also need a - # "locked" version or parameter (including generated code). + # Protects core state that can be returned via public accessors. 
+ # Thread-safety is provided on a best-effort basis to support buggy + # programs, and is not guaranteed by the public `tf.Graph` API. # # NOTE(mrry): This does not protect the various stacks. A warning will # be reported if these are used from multiple threads self._lock = threading.RLock() + # The group lock synchronizes Session.run calls with methods that create + # and mutate ops (e.g. Graph.create_op()). This synchronization is + # necessary because it's illegal to modify an operation after it's been run. + # The group lock allows any number of threads to mutate ops at the same time + # but if any modification is going on, all Session.run calls have to wait. + # Similarly, if one or more Session.run calls are going on, all mutate ops + # have to wait until all Session.run calls have finished. + self._group_lock = lock_util.GroupLock(num_groups=2) self._nodes_by_id = dict() # GUARDED_BY(self._lock) self._next_id_counter = 0 # GUARDED_BY(self._lock) self._nodes_by_name = dict() # GUARDED_BY(self._lock) @@ -3192,9 +3198,9 @@ class Graph(object): input_ops = set([t.op for t in inputs]) control_inputs = self._control_dependencies_for_inputs(input_ops) - # _create_op_helper mutates the new Operation. _lock ensures a Session.run - # call cannot occur between creating and mutating the op. - with self._lock: + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a + # Session.run call cannot occur between creating and mutating the op. + with self._mutation_lock(): ret = Operation( node_def, self, @@ -4727,6 +4733,20 @@ class Graph(object): else: self._graph_control_dependencies_stack = control_dependencies + def _mutation_lock(self): + """Returns a lock to guard code that creates & mutates ops. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_MUTATION_LOCK_GROUP) + + def _session_run_lock(self): + """Returns a lock to guard code for Session.run. + + See the comment for self._group_lock for more info. 
+ """ + return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) + # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 3988238609..1b5db17ae7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -414,8 +414,28 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections} for _ in range(3): f(self, **kwargs) + # Note that gc.get_objects misses anything that isn't subject to garbage + # collection (C types). Collections are a common source of leaks, so we + # test for collection sizes explicitly. + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") + % (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. 
new_count = len(gc.get_objects()) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 3edb8033ff..aa84eaa8ab 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -665,14 +666,14 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - return layer_utils.gather_trainable_weights( + return checkpointable_layer_utils.gather_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) @property def non_trainable_weights(self): - return layer_utils.gather_non_trainable_weights( + return checkpointable_layer_utils.gather_non_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 1beb0e396e..671508ab4e 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -604,6 +604,25 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionSameDeviceDirectSession(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/cpu:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = 
variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 837c144467..c8442b42d5 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2943,9 +2943,10 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) try: self.Enter() - # _BuildLoop calls _update_input in several places. _lock ensures a - # Session.run call cannot occur between creating and mutating new ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # _BuildLoop calls _update_input in several places. _mutation_lock() + # ensures a Session.run call cannot occur between creating and mutating + # new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access original_body_result, exit_vars = self._BuildLoop( pred, body, original_loop_vars, loop_vars, shape_invariants) finally: diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 99909ac38e..250b9285c9 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -534,10 +534,10 @@ def gradients(ys, RuntimeError: if called in Eager mode. """ - # Creating the gradient graph for control flow mutates Operations. _lock - # ensures a Session.run call cannot occur between creating and mutating new - # ops. 
- with ops.get_default_graph()._lock: # pylint: disable=protected-access + # Creating the gradient graph for control flow mutates Operations. + # _mutation_lock ensures a Session.run call cannot occur between creating and + # mutating new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients) diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index b80f84eb7c..00150fe688 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -306,10 +306,11 @@ def create_db_writer(db_uri, def _make_summary_writer(name, factory, **kwargs): resource = gen_summary_ops.summary_writer(shared_name=name) init_op_fn = lambda: factory(resource, **kwargs) - # TODO(apassos): Consider doing this instead. - # if not context.executing_eagerly(): - # ops.get_default_session().run(init_op) - ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn()) + init_op = init_op_fn() + if not context.executing_eagerly(): + # TODO(apassos): Consider doing this instead. 
+ # ops.get_default_session().run(init_op) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op) return SummaryWriter(resource, init_op_fn) @@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None): with ops.device("cpu:0"): op = smart_cond.smart_cond( should_record_summaries(), record, _nothing, name="") - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access return op diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 9232b6089a..54f359489e 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -62,11 +62,18 @@ py_test( ) py_library( + name = "layer_utils", + srcs = ["layer_utils.py"], + srcs_version = "PY2AND3", +) + +py_library( name = "data_structures", srcs = ["data_structures.py"], srcs_version = "PY2AND3", deps = [ ":base", + ":layer_utils", ], ) diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 680cf3441f..c46585b417 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -21,10 +21,9 @@ import collections import six -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import variables from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import layer_utils # TODO(allenl): We could track regular Python data structures which get assigned @@ -54,7 +53,8 @@ class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase): ("Only checkpointable objects (such as Layers or 
Optimizers) may be " "stored in a List object. Got %s, which does not inherit from " "CheckpointableBase.") % (value,)) - if isinstance(value, (base_layer.Layer, CheckpointableDataStructure)): + if (isinstance(value, CheckpointableDataStructure) + or layer_utils.is_layer(value)): if value not in self._layers: self._layers.append(value) if hasattr(value, "_use_resource_variables"): diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py new file mode 100644 index 0000000000..fdcf963d32 --- /dev/null +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to layer/model functionality.""" + +# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils +# once __init__ files no longer require all of tf.keras to be imported together. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def is_layer(obj): + """Implicit check for Layer-like objects.""" + # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). 
+ return (hasattr(obj, "call") + and hasattr(obj, "build") + and hasattr(obj, "variables")) + + +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. 
+ """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py index 2ac640ff99..cda8f95225 100644 --- a/tensorflow/python/util/lock_util_test.py +++ b/tensorflow/python/util/lock_util_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import random -import threading import time from absl.testing import parameterized @@ -48,7 +47,7 @@ class GroupLockTest(test.TestCase, parameterized.TestCase): finished.add(thread_id) threads = [ - threading.Thread(target=thread_fn, args=(i,)) + self.checkedThread(target=thread_fn, args=(i,)) for i in range(num_threads) ] diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py index 7b8f27c1b1..dbff904abe 100644 --- a/tensorflow/tools/api/generator/doc_srcs_test.py +++ b/tensorflow/tools/api/generator/doc_srcs_test.py @@ -39,27 +39,27 @@ class DocSrcsTest(test.TestCase): file_path += '/' file_path += '__init__.py' - if file_path not in FLAGS.outputs: - self.assertFalse('%s is not a valid API module' % module_name) + self.assertIn( + file_path, FLAGS.outputs, + msg='%s is not a valid API module' % module_name) def testHaveDocstringOrDocstringModule(self): for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): - if docsrc.docstring and docsrc.docstring_module_name: - self.assertFalse( - '%s contains DocSource has both a docstring 
and a ' - 'docstring_module_name. ' - 'Only one of "docstring" or "docstring_module_name" should be set.' - % (module_name)) + self.assertFalse( + docsrc.docstring and docsrc.docstring_module_name, + msg=('%s contains DocSource has both a docstring and a ' + 'docstring_module_name. Only one of "docstring" or ' + '"docstring_module_name" should be set.') % (module_name)) def testDocstringModulesAreValidModules(self): for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): if docsrc.docstring_module_name: doc_module_name = '.'.join([ FLAGS.package, docsrc.docstring_module_name]) - if doc_module_name not in sys.modules: - self.assertFalse( - 'docsources_module %s is not a valid module under %s.' % - (docsrc.docstring_module_name, FLAGS.package)) + self.assertIn( + doc_module_name, sys.modules, + msg=('docsources_module %s is not a valid module under %s.' % + (docsrc.docstring_module_name, FLAGS.package))) if __name__ == '__main__': diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index eea712c279..2403e2d966 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -39,6 +39,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/python:platform", + "//tensorflow/python:util", "@astor_archive//:astor", ], ) @@ -95,6 +96,7 @@ py_binary( deps = [ ":generate_lib", "//tensorflow:tensorflow_py", + "//tensorflow/python:util", "//tensorflow/python/debug:debug_py", ], ) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 67c413cccb..e7634cd5dc 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -388,16 +388,40 @@ def _build_guide_index(guide_src_dir): class _UpdateTags(py_guide_parser.PyGuideParser): - """Rewrites a Python guide so that each section has an explicit tag.""" + """Rewrites a Python guide so that each section has an explicit id tag. 
+ + "section" here refers to blocks delimited by second level headings. + """ def process_section(self, line_number, section_title, tag): self.replace_line(line_number, '<h2 id="%s">%s</h2>' % (tag, section_title)) +def update_id_tags_inplace(src_dir): + """Set explicit ids on all second-level headings to ensure back-links work. + + Args: + src_dir: The directory of md-files to convert (inplace). + """ + tag_updater = _UpdateTags() + + for dirpath, _, filenames in os.walk(src_dir): + for base_name in filenames: + if not base_name.endswith('.md'): + continue + full_path = os.path.join(src_dir, dirpath, base_name) + + # Tag updater loads the file, makes the replacements, and returns the + # modified file contents + content = tag_updater.process(full_path) + with open(full_path, 'w') as f: + f.write(content) + + EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) -def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): +def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): """Fix @{} references in all files under `src_dir` matching `file_pattern`. A matching directory structure, with the modified files is @@ -418,7 +442,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): using fnmatch. Non-matching files are copied unchanged. """ # Iterate through all the source files and process them. - tag_updater = _UpdateTags() for dirpath, _, filenames in os.walk(src_dir): # How to get from `dirpath` to api_docs/python/ relative_path_to_root = os.path.relpath( @@ -435,24 +458,25 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): continue full_in_path = os.path.join(dirpath, base_name) + # Set the `current_doc_full_name` so bad files can be reported on errors. 
reference_resolver.current_doc_full_name = full_in_path suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) + # Copy files that do not match the file_pattern, unmodified. if not fnmatch.fnmatch(base_name, file_pattern): shutil.copyfile(full_in_path, full_out_path) continue - if dirpath.endswith('/api_guides/python'): - content = tag_updater.process(full_in_path) - else: - with open(full_in_path, 'rb') as f: - content = f.read().decode('utf-8') + + with open(full_in_path, 'rb') as f: + content = f.read().decode('utf-8') content = reference_resolver.replace_references(content, relative_path_to_root) with open(full_out_path, 'wb') as f: f.write(content.encode('utf-8')) + class DocGenerator(object): """Main entry point for generating docs.""" @@ -538,15 +562,43 @@ class DocGenerator(object): self._do_not_descend_map) def build(self, flags): - """Actually build the docs.""" + """Build all the docs. + + This produces two outputs + + python api docs: + + * generated from modules set with `set_py_modules`. + * written to '{FLAGS.output_dir}/api_docs/python/' + + non-api docs: + + * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'. + * '@{}' references in '.md' files are replaced with links. + * '.md' files under 'api_guides/python' have explicit ids set for their + second level headings. + + Args: + flags: + * src_dir: Where to fetch the non-api-docs. + * base_dir: Base of the docs directory (Used to build correct + relative links). + * output_dir: Where to write the resulting docs. + + Returns: + The number of errors encountered while processing. + """ + # Extract the python api from the _py_modules doc_index = build_doc_index(flags.src_dir) visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + # Build the guide_index for the api_docs back links. 
root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( os.path.join(flags.src_dir, 'api_guides/python')) + # Write the api docs. parser_config = self.make_parser_config(visitor, reference_resolver, guide_index, flags.base_dir) output_dir = os.path.join(flags.output_dir, 'api_docs/python') @@ -557,8 +609,16 @@ class DocGenerator(object): yaml_toc=self.yaml_toc, root_title=root_title, search_hints=getattr(flags, 'search_hints', True)) - _other_docs(flags.src_dir, flags.output_dir, reference_resolver) + # Replace all the @{} references in files under `FLAGS.src_dir` + replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md') + # Fix the tags in the guide dir. + guide_dir = os.path.join(flags.output_dir, 'api_guides/python') + if os.path.exists(guide_dir): + update_id_tags_inplace(guide_dir) + + # Report all errors found by the reference resolver, and return the error + # code. parser_config.reference_resolver.log_errors() return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index ea6d28a02b..7a6f9fd9f7 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -51,7 +51,9 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_write(self): + def get_test_objects(self): + # These are all mutable objects, so rebuild them for each test. + # Don't cache the objects. 
module = sys.modules[__name__] index = { @@ -98,6 +100,11 @@ class GenerateTest(googletest.TestCase): guide_index={}, base_dir=base_dir) + return reference_resolver, parser_config + + def test_write(self): + _, parser_config = self.get_test_objects() + output_dir = googletest.GetTempDir() generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) @@ -127,6 +134,107 @@ class GenerateTest(googletest.TestCase): os.path.exists( os.path.join(output_dir, 'tf/TestModule/test_function.md'))) + def test_update_id_tags_inplace(self): + test_dir = googletest.GetTempDir() + test_sub_dir = os.path.join(test_dir, 'a/b') + os.makedirs(test_sub_dir) + + test_path1 = os.path.join(test_dir, 'file1.md') + test_path2 = os.path.join(test_sub_dir, 'file2.md') + test_path3 = os.path.join(test_sub_dir, 'file3.notmd') + + with open(test_path1, 'w') as f: + f.write('## abc&123') + + with open(test_path2, 'w') as f: + f.write('# A Level 1 Heading\n') + f.write('## A Level 2 Heading') + + with open(test_path3, 'w') as f: + f.write("## don\'t change this") + + generate_lib.update_id_tags_inplace(test_dir) + + with open(test_path1) as f: + content = f.read() + + self.assertEqual(content, '<h2 id="abc_123">abc&123</h2>') + + with open(test_path2) as f: + content = f.read() + + self.assertEqual( + content, '# A Level 1 Heading\n' + '<h2 id="A_Level_2_Heading">A Level 2 Heading</h2>') + + with open(test_path3) as f: + content = f.read() + + self.assertEqual(content, "## don\'t change this") + + def test_replace_refes(self): + test_dir = googletest.GetTempDir() + test_in_dir = os.path.join(test_dir, 'in') + test_in_dir_a = os.path.join(test_dir, 'in/a') + test_in_dir_b = os.path.join(test_dir, 'in/b') + os.makedirs(test_in_dir) + os.makedirs(test_in_dir_a) + os.makedirs(test_in_dir_b) + + test_out_dir = os.path.join(test_dir, 'out') + os.makedirs(test_out_dir) + + test_path1 = os.path.join(test_in_dir_a, 'file1.md') + test_path2 = os.path.join(test_in_dir_b, 'file2.md') + test_path3 = 
os.path.join(test_in_dir_b, 'file3.notmd') + test_path4 = os.path.join(test_in_dir_b, 'OWNERS') + + with open(test_path1, 'w') as f: + f.write('Use `tf.test_function` to test things.') + + with open(test_path2, 'w') as f: + f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n' + "`tf.whatever` doesn't exist") + + with open(test_path3, 'w') as f: + file3_content = ( + 'Not a .md file. Should be copied unchanged:' + '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`') + f.write(file3_content) + + with open(test_path4, 'w') as f: + f.write('') + + reference_resolver, _ = self.get_test_objects() + generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver, + '*.md') + + with open(os.path.join(test_out_dir, 'a/file1.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use <a href="../api_docs/python/tf/TestModule/test_function.md">' + '<code>tf.test_function</code></a> to test things.') + + with open(os.path.join(test_out_dir, 'b/file2.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + '<a href="../api_docs/python/tf/TestModule/TestClass/ChildClass.md">' + '<code>tf.TestModule.TestClass.ChildClass</code></a> ' + 'to test things.\n' + '`tf.whatever` doesn\'t exist') + + with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f: + content = f.read() + self.assertEqual(content, file3_content) + + with self.assertRaises(IOError): + # This should fail. 
The OWNERS file should not be copied + with open(os.path.join(test_out_dir, 'b/OWNERS')) as f: + content = f.read() + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 5cefe37782..7e4676e522 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7f7cea53068238fca7b7e4299793a0c77bea7219.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/7f7cea53068238fca7b7e4299793a0c77bea7219.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/8a152c54c401f9a9370bedf05049ac5b847bc965.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/8a152c54c401f9a9370bedf05049ac5b847bc965.tar.gz", ], - sha256 = "b645507080e07c845607f212d45e4ee79253c3c9b762531f51fbaeceb6b47391", - strip_prefix = "llvm-7f7cea53068238fca7b7e4299793a0c77bea7219", + sha256 = "dad37678abffa4f3001b1789a89f64f245bc50721f8d37b4f8b31b0695e90015", + strip_prefix = "llvm-8a152c54c401f9a9370bedf05049ac5b847bc965", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) |