diff options
author | 2017-11-30 21:15:53 -0800 | |
---|---|---|
committer | 2017-11-30 21:15:53 -0800 | |
commit | 2c4e8fcf05d3e22b0758a6f63a423b9319f9c19d (patch) | |
tree | 5d58a4760b28af0f7ca22b2620a9fb6fc940d335 | |
parent | c57796f366a0545a04424caeff1b27bbd629f8f0 (diff) | |
parent | 1ec61fafe13e5edce6e45d5a67e960efb9df618a (diff) |
Fix merge conflicts
407 files changed, 10880 insertions, 3826 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bb41f92306..c8b4bfffd4 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. std::vector<const Node*> ready; - std::vector<int> pending_count(total_num_nodes, 0); + std::vector<int> pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector<string> nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -625,6 +624,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -1745,7 +1761,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1755,7 +1770,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -2325,11 +2340,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2393,7 +2409,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2408,8 +2424,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2425,6 +2441,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01bee..aac333d9e2 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -81,12 +81,20 @@ struct TF_Graph { std::unordered_map<tensorflow::string, tensorflow::Node*> name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index f52248e7d5..191e9c3413 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -161,7 +161,7 @@ class GradientTape { // the tape refer to it); to aid in tape garbage collection. std::unordered_map<int64, int64> tensor_usage_; - // If true, all activations are deleted in the first call to ComputeGradient. + // If false, all activations are deleted in the first call to ComputeGradient. // Else, only when this is destructed. bool persistent_; }; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index ba5a9268b4..37629a74ba 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,6 +22,7 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, @@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, mutex_lock l(graph->mu); op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ -75,6 +78,13 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc4..28ac40df18 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -418,7 +418,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -483,7 +483,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +496,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27..cf01bee325 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); namespace foo { namespace bar { @@ -86,7 +86,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +99,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276a..413efd9cea 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74c9791f5e..aceedeb823 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -210,6 +210,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. + if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea1..454f0aeae9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6cad2b0824..fff1a7f57b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -417,6 +417,20 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5e06f9a724..035cdea178 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest class CategoricalTest(XLATestCase): """Test cases for random-number generating operators.""" + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + def _chi2(self, expected, actual): """Returns Chi2 GOF statistic.""" actual = np.asarray(actual) @@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase): """ with self.test_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) - op = random_ops.multinomial(logits, num_samples) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) d = sess.run(op) batch_size, num_classes = logits.shape @@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase): return freqs_mat - def _testRngIsNotConstant(self, rng, dtype): + def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: with self.test_scope(): - x = rng(dtype) + x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. @@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase): (not np.array_equal(y, w))) def testCategoricalIsNotConstant(self): - def rng(unused_dtype): - return random_ops.multinomial([[1., 1., 1.]], 10) + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) def testCategoricalIsInRange(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: - with self.test_scope(): - x = random_ops.multinomial( - array_ops.ones(shape=[1, 20], dtype=dtype), 1000) - y = sess.run(x) - self.assertTrue((y >= 0).sum() == 1000) - self.assertTrue((y < 20).sum() == 1000) + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000..3260e63b23 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[ix] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[ix_head]) + elif func == np.cumprod: + init = np.ones_like(x[ix_head]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(x, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index d57273d844..6a1a5467e0 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -52,6 +52,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"Conv2DBackpropInput", "input_sizes"}, {"Conv3DBackpropFilterV2", "filter_sizes"}, {"Conv3DBackpropInputV2", "input_sizes"}, + {"Cumprod", "axis"}, + {"Cumsum", "axis"}, {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, {"DynamicStitch", "indices"}, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b873..03603ee9ba 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5726d8294a..267268298c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -1067,6 +1067,10 @@ FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) { enqueue_or_update_merge(out); } } + // Return if there are no merge nodes. + if (merges.empty()) { + return gtl::nullopt; + } auto it = merges.begin(); Cluster* merge_cluster = *it; for (++it; it != merges.end(); ++it) { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6302fece1f..a1720ff919 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -54,6 +54,7 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e..468af34aab 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. -#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -42,27 +42,44 @@ class FusedBatchNormOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType stats_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &stats_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, stats_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index_); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index_); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -78,6 +95,7 @@ class FusedBatchNormOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: @@ -101,19 +119,36 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + auto grad_output = ctx->Input(0); auto activation = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( + + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType stats_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(3), &stats_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + grad_output = builder->ConvertElementType(grad_output, stats_type); + activation = builder->ConvertElementType(activation, stats_type); + + xla::ComputationDataHandle output = builder->BatchNormGrad( activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } private: @@ -122,6 +157,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index c5017704e2..aaddbe811c 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. +xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. - int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, othwerwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector<int64> zeros(filter_shape.dims()); - std::vector<int64> strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector<int64> zeros(shape.dims()); - // std::vector<int64> strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,23 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -184,7 +260,7 @@ class ConvOp : public XlaOpKernel { dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); dims.add_output_spatial_dimensions(dim); @@ -204,6 +280,7 @@ class ConvOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -241,6 +318,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -263,6 +341,23 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -336,6 +431,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -373,6 +469,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -392,6 +489,23 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -426,9 +540,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -440,11 +552,17 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector<int64> rhs_dilation(num_spatial_dims_); std::vector<int64> ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < num_spatial_dims_; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + for (int i = 0; i < num_spatial_dims_; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); - dnums.add_output_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. @@ -501,31 +619,17 @@ class ConvBackpropFilterOp : public XlaOpKernel { /*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. - std::vector<int64> transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.output_spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e58..644abd5905 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,8 +23,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array<DataType, 4> kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array<DataType, 5> kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: @@ -85,10 +85,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000..650f8c7dc8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <vector> + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. +constexpr std::array<DataType, 3> kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector<int64> window_strides(input_shape.dims(), 1); + std::vector<int64> window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector<std::pair<int64, int64>> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. + if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. + bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d..06838d1625 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -25,58 +25,72 @@ limitations under the License. namespace tensorflow { namespace { +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec<int32>(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits<int32>::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast<int32>(dim_size); + } + } else { + auto vec = shape_constant->vec<int64>(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec<int32>(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast<int32>(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({dims})); - auto vec = shape_constant.vec<int32>(); - - // TODO(dga): support int64. b/28119922. - for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast<int32>(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50..2346c62ad1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -121,5 +122,31 @@ class ResourceGatherOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &dtype, &shape)); + const int rank = shape.dims(); + Tensor shape_constant(DT_INT32, TensorShape({rank})); + auto vec = shape_constant.vec<int32>(); + // TODO(dga): support int64. b/28119922. + for (int i = 0; i < rank; ++i) { + int64 dim_size = shape.dim_size(i); + OP_REQUIRES( + ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), + errors::InvalidArgument("Shape does not support tensors > int32max", + " but dim ", i, " is ", dim_size)); + vec(i) = static_cast<int32>(dim_size); + } + + ctx->SetConstantOutput(0, shape_constant); + } +}; + +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 7ffe0aa6df..943248aedb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, case xla::F16: return builder->ConstantR0<xla::half>(static_cast<xla::half>(value)); break; + case xla::BF16: + return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value)); + break; case xla::F32: return builder->ConstantR0<float>(static_cast<float>(value)); break; diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273..43d0e17c2c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + hlo_profile_printer_(static_data.hlo_profile_printer) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -43,6 +44,15 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, if (static_data.requires_runtime_context) { args_[static_data.num_args - 1] = &context_; } + + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); + } } XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { @@ -50,6 +60,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a788922..3c4314d498 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include <functional> +#include <cassert> #include <string> #include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" @@ -27,6 +27,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinter; } namespace tensorflow { @@ -48,12 +49,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. - using RawFunction = std::function<void( - void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps)>; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -81,21 +80,29 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer. Null if profiling is disabled. + const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. + int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -113,7 +120,7 @@ class XlaCompiledCpuFunction { context_.error = false; context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast<const void**>(args_), temps_); + const_cast<const void**>(args_), temps_, profile_counters_); return !context_.error; } @@ -162,6 +169,16 @@ class XlaCompiledCpuFunction { return static_cast<const void* const*>(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +212,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + const xla::HloProfilePrinter& hlo_profile_printer() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,6 +231,9 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. xla::ExecutableRunOptions run_options_; tensorflow::XlaLocalRuntimeContext context_; @@ -216,6 +242,7 @@ class XlaCompiledCpuFunction { const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5..78e770c62b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -178,6 +178,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function<xla::Computation()>& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa362..55d2995987 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -102,6 +102,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -155,6 +160,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9c3e15d2fa..ec9e535b70 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" @@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0<bfloat16>(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0<float>(std::numeric_limits<float>::epsilon()); case DT_DOUBLE: @@ -169,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)); + break; case xla::F16: literal = *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d..f727f20464 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -90,21 +90,6 @@ xla::StatusOr<size_t> ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -177,7 +162,7 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast<xla::cpu::CpuExecutable*>(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); @@ -211,6 +196,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer = + &cpu_executable->hlo_profile_printer(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b4cc9ba2d..79d501b511 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -417,6 +417,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 76bcf594e6..06845a674e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -210,6 +210,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index cce9310003..9febea8dcf 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -625,7 +625,41 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + NoteError(lhs_shape_or_status.status()); + return ComputationDataHandle(); + } + std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DotRequest request; + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_dimension_numbers() = dimension_numbers; + + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dot_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making Dot request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::Conv( diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index d2dbbbbebb..531b98cfb9 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -393,6 +393,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 93d3cd425f..250df5f4d5 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -252,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0<int32>(1); case S64: return *Literal::CreateR0<int64>(1); + case F16: + return *Literal::CreateR0<half>(static_cast<half>(1.0f)); + case BF16: + return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)); case F32: return *Literal::CreateR0<float>(1); case F64: @@ -263,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal, case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0<half>(static_cast<half>(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index f37e529caf..069d1b33ca 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -285,11 +285,11 @@ class Literal { std::unique_ptr<Literal> Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. StatusOr<std::unique_ptr<Literal>> Reshape( - tensorflow::gtl::ArraySlice<int64> shape) const; + tensorflow::gtl::ArraySlice<int64> dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 5bb81b80dd..bdf92eaed1 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -195,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice<int64>& window, const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr<std::vector<float>> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, + const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { + std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; std::vector<int64> window_counts(window.size(), 0); std::vector<int64> pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique<std::vector<float>>(window_counts[0]); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 62d455d71a..58e1a84461 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated( const Array4D<float>& lhs, const Array4D<float>& rhs, - std::pair<int64, int64> stride, Padding padding, + std::pair<int64, int64> kernel_stride, Padding padding, std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -184,6 +184,12 @@ class ReferenceUtil { const std::function<float(float, float)>& reduce_func, const tensorflow::gtl::ArraySlice<int64>& window, const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding); + static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, + const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding); static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric( const Array4D<float>& operand, float init, const std::function<float(float, float)>& reduce_func, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 71491218aa..b1d0345e70 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -597,9 +597,13 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1616,8 +1620,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 56dfb1cf0b..3d70505f6e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2138,8 +2138,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr<HloComputation> dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index c6193b3fbb..2bbae25aee 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -149,6 +149,15 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( if (!rewrite_training_op_) { return Status::OK(); } + + std::vector<HloInstruction*> added_instructions; + auto add = [&](std::unique_ptr<HloInstruction> inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); @@ -160,7 +169,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( Literal::CreateR0<float>(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); @@ -169,14 +178,12 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); - + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector<int64> dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -185,105 +192,110 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. - auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. - auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } @@ -317,52 +329,69 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector<HloInstruction*> added_instructions; + auto add = [&](std::unique_ptr<HloInstruction> inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); @@ -385,6 +414,13 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector<HloInstruction*> added_instructions; + auto add = [&](std::unique_ptr<HloInstruction> inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); @@ -403,23 +439,22 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( Literal::CreateR0<float>(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector<int64> dimensions_without_feature; @@ -429,126 +464,131 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8fba8ef5e5..09681b34e7 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1360,10 +1360,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1708,9 +1711,8 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence), - ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) - .ConsumeValueOrDie(); + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) + .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index bbb42d494b..13825fe05b 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -167,11 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - xla::MakeUnique<SequentialHloOrdering>( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -296,7 +295,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { module_sequence.emplace(computation, order); auto liveness = BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( - module.get(), module_sequence)) + module.get(), module_sequence)) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -625,9 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -738,9 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e1eed498f6..bf41d5ce07 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -250,6 +250,8 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -281,6 +283,38 @@ cc_library( ) cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], hdrs = ["dot_op_emitter.h"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index addd7284c5..988f632748 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -528,9 +528,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, - BufferAssigner::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -642,10 +642,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique<SequentialHloOrdering>( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -824,7 +824,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run( - module, xla::MakeUnique<SequentialHloOrdering>(module, module_sequence), + module, + xla::MakeUnique<SequentialHloOrdering>(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d7..1c04c9835e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 4c40dae512..4ccff756a3 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -518,14 +518,14 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2 || - ProfitableToImplementDotInUntiledLlvmIr(dot_) == - DotInLlvmIrProfitable::kYes) { + if (dot_.shape().dimensions_size() != 2) { return false; } - if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && - !primitive_util::IsIntegralType(dot_.shape().element_type())) { + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { return false; } @@ -575,30 +575,50 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/8, - /*tile_cols=*/tiling_factor, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = 8; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + ir_builder_, kernel_name, lhs_op, rhs_op, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + result_op, ir_builder_); + emitter.Emit(); + }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/tiling_factor, - /*tile_cols=*/8, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = tiling_factor; + int64 tile_cols = 8; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + ir_builder_, kernel_name, lhs_op, rhs_op, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + result_op, ir_builder_); + emitter.Emit(); + }); } return true; @@ -977,9 +997,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } - if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == - DotInLlvmIrProfitable::kYes || - ProfitableToImplementDotInTiledLlvmIr(hlo)) { + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { return false; } @@ -1010,46 +1028,11 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; - } - } - return DotInLlvmIrProfitable::kNo; +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. +bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kDot && + hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1; } bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index c9168ccc0f..2badb26f90 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -32,19 +32,9 @@ namespace cpu { bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a untiled LLVM IR dot operation be profitable over calling into -// Eigen or emitting a tiled LLVM IR implementation. Possible return values -// are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot); +// Returns true to indicate that |hlo| is a dot, and that it is profitable to +// switch the layout of the |hlo|'s RHS operand to column major. +bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation // for |dot|. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 502dd2e738..bb75d3f49e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "llvm/CodeGen/TargetRegisterInfo.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -124,131 +127,27 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... | param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - (++arg_iter)->setName("prof_counters"); - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). - if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. + compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -898,6 +797,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -1452,7 +1356,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Argument* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = @@ -1590,7 +1494,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // Here we assume that the largest register is a vector register. int max_vector_register_size_in_bytes = target_machine_features_.largest_register_size_in_bytes( - compute_function_); + compute_function_->function()); int vector_register_size_in_elements = max_vector_register_size_in_bytes / @@ -1748,19 +1652,6 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. -// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr<bool> IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function, @@ -1783,9 +1674,9 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( std::find(dimensions.begin(), dimensions.end(), arg->shape().layout().minor_to_major(0)) != dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. @@ -1995,7 +1886,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2410,7 +2301,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2427,7 +2318,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2442,7 +2333,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. - compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); @@ -2560,7 +2451,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2642,7 +2533,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { if (prof_counter) { profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - ir_builder_.CreateRetVoid(); return Status::OK(); } @@ -2783,43 +2673,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector<llvm::Type*> compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - compute_function_params.push_back(i64_ptr_type); - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return GetArg(compute_function_, arg_index); + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2965,7 +2828,8 @@ Status IrEmitter::EmitParallelForkJoin( HloInstruction* root = computation->root_instruction(); // Build ParallelForkJoin function type. - std::vector<llvm::Type*> compute_function_params = GetComputeFunctionParams(); + std::vector<llvm::Type*> compute_function_params = + compute_function_->GetComputeFunctionParams(); // Number of parallel compute functions. compute_function_params.push_back(ir_builder_.getInt32Ty()); // Array of partitions. There is an array element for each @@ -3066,7 +2930,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3127,8 +2991,19 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + ParallelLoopEmitter::LoopBounds dynamic_loop_bounds( + num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = + compute_function_->GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = + compute_function_->GetDynamicLoopBound(i * 2 + 1); + } + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3138,60 +3013,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. - std::vector<llvm::Value*> dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. - SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 351c95278c..6b576d16bb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include <stddef.h> #include <map> +#include <memory> #include <string> #include <unordered_map> #include <vector> @@ -30,6 +31,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -233,13 +235,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector<llvm::Type*> GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. llvm::Argument* GetProfileCountersArgument(); @@ -252,11 +247,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -346,15 +336,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -476,8 +457,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. - llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr<IrFunction> compute_function_; llvm::IRBuilder<> ir_builder_; // Maps HLOs to their index into the profile counter array. @@ -490,7 +473,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000..701bce2cbf --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <iterator> + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/GetComputeFunctionParams(), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. + function_ = llvm::Function::Create(/*Ty=*/function_type, + /*Linkage=*/linkage, + /*N=*/AsStringRef(function_name), + /*M=*/llvm_module_); + function_->setCallingConv(llvm::CallingConv::C); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + // Add the optize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size_requested) { + function_->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + if (enable_fast_math) { + function_->addFnAttr("unsafe-fp-math", "true"); + function_->addFnAttr("no-infs-fp-math", "true"); + function_->addFnAttr("no-nans-fp-math", "true"); + function_->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +std::vector<llvm::Type*> IrFunction::GetComputeFunctionParams() { + llvm::Type* i8_ptr_type = + llvm::Type::getInt8PtrTy(llvm_module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module_->getContext()); + std::vector<llvm::Type*> compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000..b7516b403e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addtion IrFunction saves the callers IR insert point during contruction, +// and restores it after desctruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr<IrFunction> compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function.function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Returns an array of compute function parameter types. + std::vector<llvm::Type*> GetComputeFunctionParams(); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + // Returns the encapculated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Value* that represents this functions "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this functions + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Argument that represents this functions parameters argument. + llvm::Argument* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this functions "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this functions "prof_counters" + // argument. + llvm::Argument* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. + llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Argument* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Argument* dynamic_loop_bounds_arg_ = nullptr; + llvm::Argument* profile_counters_arg_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc index 3f2d101959..69466fd32e 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -52,8 +52,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( tensorflow::gtl::FlatMap<const HloInstruction*, bool> should_make_rhs_col_major_cache; auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { + if (!ProfitableToMakeDotRhsColumnMajor(instruction)) { return false; } @@ -69,8 +68,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( bool result = std::all_of( rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInUntiledLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && + return ProfitableToMakeDotRhsColumnMajor(*user) && user->operand(0) != rhs; }); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000..91e704e3d0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, const LoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000..492d5953c4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. +// Then 'dynamic_loop_bounds' will contain the following ir values for +// the two most-major dimenions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + using LoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>; + + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and 'target_array.' shape to set the static loop + // bounds for the most-minor dimensions. + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const LoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const LoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 828ae675d7..f198c4c08e 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -55,19 +55,7 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward convolution as well. - HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = @@ -75,6 +63,9 @@ MatchBackwardFilter(HloInstruction* conv) { auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); @@ -98,7 +89,8 @@ MatchBackwardFilter(HloInstruction* conv) { } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -169,53 +161,32 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. - HloComputation* parent_computation = conv->parent(); - std::vector<int64> transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < input_spatial_dims.size(); ++i) { backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); } - for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_output_spatial_dimensions(output_spatial_dims[i]); + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). - backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), output_spatial_dims[i])); + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}), + return std::make_tuple(true, std::vector<HloInstruction*>({conv}), backward_conv_window, backward_conv_dnums); } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 112c496e1f..34e6bdb117 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -46,18 +46,18 @@ class ConvolutionFoldingTest : public HloTestBase { // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); @@ -86,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -136,7 +136,7 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + EXPECT_TRUE(FoldConvolution(module.get())); } // Extracted from block35 training. @@ -155,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -189,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -222,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1b863c9e3c..abc739d181 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -246,6 +246,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 11290eda4f..c29fee0879 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d8..05017008e2 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr<HloModule> MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c874..387b649a73 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e984bdb5f7..5d0cfba1fc 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -118,6 +118,9 @@ message HloInstructionProto { // Shape of outfeed request. xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c215cc48d6..014a851c96 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 353b30bc69..ccedda2a03 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -313,11 +313,17 @@ class HloComputation { replacements, HloModule* module = nullptr, const string& suffix = "clone"); - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessariy condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. A computation has a diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6fcc01dd64..0ed64e6779 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 40e67c8780..1e5f0f797a 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -55,7 +55,8 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a2708..5a56607a66 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b2c4351896..a5d39fe086 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -621,8 +621,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = @@ -664,8 +667,11 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = @@ -705,8 +711,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ba75e2ef1b..0809fe780d 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -109,7 +109,8 @@ std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter( }; return MakeUnique<HloProfilePrinter>( - computation_infos, hlo_profile_index_map.computation_count(), deleter); + computation_infos, hlo_profile_index_map.computation_count(), + /*profile_counters_size=*/max_profile_index, deleter); } HloExecutionProfile::HloExecutionProfile( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c30c432654..b4bac18bcd 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -118,6 +118,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( MakeUnique<ConvolutionDimensionNumbers>( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -332,6 +336,17 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique<DotDimensionNumbers>(dimension_numbers); + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -1086,7 +1101,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1138,6 +1152,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: CHECK_EQ(new_operands.size(), 1); clone = CreateCrossReplicaSum(shape, new_operands[0]); @@ -1509,7 +1528,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1582,6 +1600,10 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); // Reduction results are determined by the reduction dimension and the // reduction computation. @@ -1990,6 +2012,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const { if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } if (opcode() == HloOpcode::kWhile) { extra.push_back(StrCat("condition=%", while_condition()->name())); @@ -2086,6 +2111,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -3051,6 +3079,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { return result; } +string HloInstruction::DotDimensionNumbersToString() const { + string result; + if (dot_dimension_numbers_ == nullptr) { + return result; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result += "lhs_batch_dims="; + StrAppend(&result, Join(dnums.lhs_batch_dimensions(), ",")); + } + result += "lhs_contracting_dims="; + StrAppend(&result, Join(dnums.lhs_contracting_dimensions(), ",")); + + result += ","; + if (!dnums.rhs_batch_dimensions().empty()) { + result += "rhs_batch_dims="; + StrAppend(&result, Join(dnums.rhs_batch_dimensions(), ",")); + } + result += "rhs_contracting_dims="; + StrAppend(&result, Join(dnums.rhs_contracting_dimensions(), ",")); + + return result; +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cda8b07c61..768c027a42 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -160,6 +160,12 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr<HloInstruction> CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -915,6 +921,15 @@ class HloInstruction { // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1173,6 +1188,9 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; + // Describes the [begin, end) index range for a slice. std::vector<int64> slice_starts_; std::vector<int64> slice_limits_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b12fc8d3..11420cae63 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1182,12 +1185,15 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); EXPECT_EQ(dot->ToString(false, false), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims=1,rhs_contracting_dims=0"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index faaf73ea1c..6fe2134466 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,15 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(name), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h index 316753a82a..2f056490ae 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.h +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -65,9 +65,11 @@ class HloProfilePrinter { HloProfilePrinter( HloComputationInfo* computation_infos, int64 computation_infos_size, + int64 profile_counters_size, std::function<void(HloComputationInfo*, int64)> deleter = nullptr) : computation_infos_(computation_infos), computation_infos_size_(computation_infos_size), + profile_counters_size_(profile_counters_size), deleter_(std::move(deleter)) {} HloProfilePrinter(HloProfilePrinter&& other) { @@ -79,10 +81,13 @@ class HloProfilePrinter { HloProfilePrinter(const HloProfilePrinter&) = delete; HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; - // Convert the profile counter sequence `counters` to a human readable string + // Converts the profile counter sequence `counters` to a human readable string // representation. string ToString(const int64* counters, double clock_rate_ghz) const; + // Returns the size of the profile buffer expected by this printer. + int64 profile_counters_size() const { return profile_counters_size_; } + ~HloProfilePrinter(); private: @@ -90,6 +95,7 @@ class HloProfilePrinter { // is manifested as the deleter_ function. HloComputationInfo* computation_infos_ = nullptr; int64 computation_infos_size_ = 0; + int64 profile_counters_size_ = 0; std::function<void(HloComputationInfo*, int64)> deleter_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 017f996bc4..d09de7b528 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -566,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -603,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index d1adec31c2..447c244666 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -246,7 +246,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 15188c4057..ea7775b18a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -75,7 +75,11 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); } Status HandleConvolution(HloInstruction* convolution) override { @@ -143,9 +147,13 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); + // Bitcasts that are not the root of a computation can be any shape. + // Bitcasts that are the root of a computation must have the same shape + // byte size as their operand. + if (bitcast->parent()->root_instruction() == bitcast) { + TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == + shape_size_fn_(bitcast->operand(0)->shape())); + } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 476e86fa72..2c2a02f637 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 29cc0f81bd..d951a37d5d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { void KernelSupportLibrary::For( @@ -62,4 +63,47 @@ void KernelSupportLibrary::If( false_block_generator(); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); } + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function<void(KernelSupportLibrary::ArgumentVector)>& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector<llvm::Type*> arg_types; + std::transform(arguments.begin(), arguments.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm::Function::Create( + function_type, llvm::GlobalValue::InternalLinkage, + llvm_ir::AsStringRef(kernel_name), module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. + ir_builder->SetInsertPoint(return_inst); + + std::vector<llvm::Value*> arg_values; + std::transform(function->arg_begin(), function->arg_end(), + std::back_inserter(arg_values), std::addressof<llvm::Value>); + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(arguments)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 9bafb7b577..997b84bb27 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -118,6 +118,38 @@ class KernelSupportLibrary { const std::function<void()>& true_block_generator, const std::function<void()>& false_block_generator = []() {}); + using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); + // } + // + // ... + // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + static void EmitAndCallOutlinedKernel( + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function<void(ArgumentVector)>& kernel_body_generator); + + // Thin wrapper around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + ir_builder, kernel_name, {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + private: llvm::IRBuilder<>* ir_builder_; bool prevent_unrolling_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288d..7d8c05fffa 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast<unsigned char>(character); + return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast<unsigned char>(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b5225..4139c2700b 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2..4258cf1687 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. + EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. + EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d997cab83f..fa62080be4 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1381,6 +1381,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3df1911d07..7178eb40dd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -90,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -549,8 +548,98 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); } -/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. +// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims, + tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. + auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims, + tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool { + tensorflow::gtl::FlatSet<int64> dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + return Status::OK(); +} + +} // namespace + +/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -570,37 +659,62 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that contracting dimension sizes match. + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); + } + + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } + + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. + // dimensions except the contracted and batch dimensions. std::vector<int64> dimensions; + std::unordered_set<int64> rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } @@ -816,8 +930,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs, tensorflow::strings::StrCat("rhs of binary operation ", BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0aadb98a40..382c4f8abc 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -229,11 +229,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr<Shape> InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr<Shape> InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. // Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index be93c879c0..6e53d2d609 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar <dot> vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D <dot> 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector <dot> vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix <dot> vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector <dot> matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix <dot> matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. +TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with out-of-range dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index fb55d4e543..42b616f4c3 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -102,6 +102,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -121,8 +125,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.input_spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + *new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6ac32e88f1..caa1a111ad 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -376,5 +385,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +// Test that a transpose of every dimension in the activations gets folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4e90491b55..6d0d367981 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -1207,6 +1205,33 @@ StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction( return handle; } +StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + return handle; +} + StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1629,6 +1654,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kSendRequest: { *is_functional = false; break; @@ -2453,6 +2487,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2732,6 +2773,15 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); @@ -3151,8 +3201,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 317c631dca..b6686c3f1a 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -153,6 +153,10 @@ class UserComputation { StatusOr<ComputationDataHandle> AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. + StatusOr<ComputationDataHandle> AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr<ComputationDataHandle> AddBroadcastInstruction( const BroadcastRequest& broadcast_request); diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae..e45673300b 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -334,50 +334,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index b38ee907d7..b2fd64a4d9 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -289,7 +289,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -558,7 +558,7 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) { // the loop aren't removed, just cloned and added back to the loop. // Nevertheless our infrastructure sees loop simplification as removal of // these nodes and currently doesn't allow it. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); return false; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6c..789eba5780 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f2..4c83750f3e 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. + tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 74fa0b2f2e..9e3f06e527 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -694,9 +694,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 2ea1bd95cb..df5b450438 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -170,7 +170,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac6..f9d25945bc 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr<CopyNoAssign> original(value); StatusOr<CopyNoAssign> copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index a1c53ef2aa..ac3f3f4c9d 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -61,6 +61,15 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { error_spec_); } +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f)); + builder.Log(x); + + ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {}, + error_spec_); +} + XLA_TEST_F(Bfloat16Test, NegateScalarF16) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f))); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1d27880fb1..d8fe12a72d 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -194,7 +194,7 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice<GlobalData*> arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the HloEvaluator. diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index bfb04fd9f9..680d790b57 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -561,5 +561,25 @@ TEST_F(DotOperationTest, TransposeFolding) { } } +XLA_TEST_F(DotOperationTest, DotGeneralUnimplemented) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR3FromArray3D<float>( + {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}); + auto rhs = builder.ConstantR3FromArray3D<float>( + {{{1.0, 0.0}, {0.0, 1.0}}, {{0.0, 1.0}, {1.0, 0.0}}}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); + builder.DotGeneral(lhs, rhs, dot_dnums); + + auto status = Execute(&builder, {}).status(); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Dot with batch dimensions not implemented.")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1..89fa6ed9f7 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -133,8 +136,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 0601a1466b..aa035f0ba5 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -962,68 +962,114 @@ struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo<R1ReduceWindowTestData>& data) { string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(data.param.base_bounds, "x"), "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(data.param.window_bounds, "x"), "__strides_", + tensorflow::str_util::Join(data.param.strides, "x"), "__pad_low_", + tensorflow::str_util::Join(data.param.pad_low, "x"), "__pad_high_", + tensorflow::str_util::Join(data.param.pad_high, "x"), "__reducer_", + data.param.reducer == kAdd ? "add" : "max"); return str; } @@ -1044,15 +1090,18 @@ TEST_P(R1ReduceWindowTest, DoIt) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_arg, client_->TransferToServer(*input_literal)); + std::vector<std::pair<int64, int64>> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; + auto computation = param.reducer == kAdd ? CreateScalarAddComputation(F32, &b) : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0<float>(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + b.ReduceWindowWithGeneralPadding( + /*operand=*/b.Parameter(0, input_literal->shape(), "p0"), + /*init_value=*/b.ConstantR0<float>(kInitValue), + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1062,7 +1111,8 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); ComputeAndCompareR1<float>(&b, tensorflow::gtl::ArraySlice<float>(*expected), {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750a..4db566f784 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -211,6 +212,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); } @@ -223,30 +231,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. +// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 49f673f5f0..f3f10517e3 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector<Shape> shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector<Shape> shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e595df3052..fe5d29a6b6 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -191,9 +191,9 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 2ba1a2d904..6800c3d7fa 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -498,6 +498,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -732,9 +749,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -885,6 +899,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -914,7 +929,7 @@ message OpRequest { BatchNormInferenceRequest batch_norm_inference_request = 38; FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; - // Next: 43 + // Next: 44 } } diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe..c7c128bf14 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index aba356d616..a115d1610e 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,6 +34,8 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001..9e32bee505 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -208,6 +208,8 @@ class ASBSQueue : public BatchScheduler<TaskType> { // place any more tasks in this batch. void ReleaseBatch(const ASBSBatch<TaskType>* batch); + size_t max_task_size() const override { return options_.max_batch_size; } + private: std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_; const QueueOptions options_; diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc index a07cd6d834..e2aac54eeb 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -186,6 +186,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { queue_options.max_enqueued_batches = 2; TF_ASSERT_OK( scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + EXPECT_EQ(10, queue_0->max_task_size()); queue_options.max_batch_size = 0; // Queue must have max_batch_size > 0. EXPECT_FALSE( diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf..91065db249 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + private: explicit BasicBatchScheduler( std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue); diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc index e020301795..187823151c 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc @@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) { std::unique_ptr<BasicBatchScheduler<FakeTask>> scheduler; TF_ASSERT_OK( BasicBatchScheduler<FakeTask>::Create(options, callback, &scheduler)); + EXPECT_EQ(10, scheduler->max_task_size()); EXPECT_EQ(0, scheduler->NumEnqueuedTasks()); EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity()); TF_ASSERT_OK(ScheduleTask(3, scheduler.get())); diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439a..e18cf6c350 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -178,6 +178,10 @@ class BatchScheduler { // This method is useful for monitoring, or for guaranteeing a future slot in // the schedule (but being mindful about the caveats listed above). virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; }; ////////// diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137..1d2158062e 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -248,6 +248,9 @@ class Queue { // BatchScheduler::SchedulingCapacity(). size_t SchedulingCapacity() const; + // Returns the maximum allowed size of tasks submitted to the queue. + size_t max_task_size() const { return options_.max_batch_size; } + // Called by a thread that is ready to process a batch, to request one from // this queue. Either returns a batch that is ready to be processed, or // nullptr if the queue declines to schedule a batch at this time. If it @@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler<TaskType> { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { return queue_->max_task_size(); } + private: // The scheduler that owns 'queue_'. std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_; diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 3e924ae5f1..3ac79a8fdc 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { queue_options.max_enqueued_batches = max_enqueued_batches; std::unique_ptr<BatchScheduler<FakeTask>> queue; TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue)); + EXPECT_EQ(2, queue->max_task_size()); EXPECT_EQ(0, queue->NumEnqueuedTasks()); EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity()); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index 0d46565a19..ccee9530b6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -97,7 +97,7 @@ class IndicesRowIterator } bool operator<(const IndicesRowIterator& other) const { - return (row_idx_ < other.row_idx_); + return (row_idx_ < other.row_idx_); } bool operator==(const IndicesRowIterator& other) const { diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8..294e04002a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, container=None): """Creates a QuantileAccumulator object. @@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` """ @@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, + max_elements=max_elements, num_quantiles=num_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 155c91cb97..0508006047 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index 594c2492d4..aaae18a313 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -158,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382f..e4213ea2a4 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c607546f4a..5ec1a8d04f 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -211,7 +211,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 0128946e45..819b6213ea 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -899,6 +899,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 18b71d1f9a..2e3ee2c96b 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -225,6 +225,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index d060eda0a7..bae66ffd42 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -225,6 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, new_original_op, op_def) #Use Graph's hidden methods to add the op + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index f7d8a084d9..3b1c33063f 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -18,6 +18,7 @@ py_library( "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 7c6244f22b..c9ad091bd4 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -66,6 +66,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1d4817fa26..4112de31c1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -277,7 +277,7 @@ py_test( py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -419,12 +419,14 @@ py_test( py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 6b5b53cc0f..ba1be0690f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -22,8 +22,10 @@ import os import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -156,6 +158,13 @@ class ShuffleDatasetTest(test.TestCase): for i in range(5): self.assertEqual(10, counts[i]) + def testSeedNoneSeed2NonNone(self): + with self.assertRaises(ValueError): + dataset_ops.ShuffleDataset(dataset_ops.Dataset.range(5), + buffer_size=1, + seed=None, + seed2=10) + class ShuffleDatasetSerializationTest(test.TestCase): @@ -474,5 +483,76 @@ class ShuffleDatasetSerializationTest(test.TestCase): self.assertEqual(expected_outputs_sorted, sorted(actual)) +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting that the iterator is exhausted after producing 100 items should + # fail. + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 25ed58cdf5..1f35ee056b 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -41,6 +41,25 @@ py_library( ) py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( name = "readers", srcs = [ "readers.py", @@ -63,6 +82,19 @@ py_library( ) py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( name = "transformation_ops", srcs = [ "batching.py", diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000..7d727165fe --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000..460732d65e --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import random_ops +from tensorflow.python.data.ops import dataset_ops + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + def _apply_fn(dataset): # pylint: disable=missing-docstring + random_ds = random_ops.RandomDataset(seed).apply( + batching.batch_and_drop_remainder(2)) + if count is not None and count is not -1: + random_ds = random_ds.take(count) + + def map_fn(seeds): + return dataset_ops.ShuffleDataset( + input_dataset=dataset, + buffer_size=buffer_size, + seed=seeds[0], + reshuffle_each_iteration=False, + seed2=seeds[1]) + + return random_ds.flat_map(map_fn) + + return _apply_fn diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d..49451446b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). + """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. - with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. @@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testInvalidDimensionsOpError(self): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "elements must be either positive integers or `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, but + # doesn't match the expected input shape. + with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) + sess.run(bijector.forward(x), + feed_dict=feed_dict) - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testInputOutputMismatchOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + with self.assertRaisesError( + "Input to reshape is a tensor with"): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with self.assertRaisesError( + "Input to reshape is a tensor with"): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise NotImplementedError("Subclass failed to implement `build_shapes`.") + + +class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py index b84502003a..0fe9f6aa78 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py @@ -48,7 +48,9 @@ class AbsoluteValue(bijector.Bijector): ```python - abs = ds.bijectors.AbsoluteValue() + tfd = tf.contrib.distributions + + abs = tfd.bijectors.AbsoluteValue() abs.forward([-1., 0., 1.]) ==> [1., 0., 1.] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py index ae14288393..f51c48d2dd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py @@ -124,17 +124,17 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): #### Example Use ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors + tfd = tf.contrib.distributions + tfb = tfd.bijectors dims = 5 # A common choice for a normalizing flow is to use a Gaussian for the base # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( hidden_layers=[512, 512])), event_shape=[dims]) @@ -143,10 +143,10 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( hidden_layers=[512, 512]))), event_shape=[dims]) @@ -158,10 +158,10 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): # poor choice. Here's an example of using a "shift only" version and with a # different number/depth of hidden layers. shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( hidden_layers=[32], shift_only=shift_only), is_constant_jacobian=shift_only), diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py index b1d8f2f41b..8654cc39d0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py @@ -40,9 +40,9 @@ class Permute(bijector_lib.Bijector): """Permutes the rightmost dimension of a `Tensor`. ```python - bs = tf.contrib.distributions.bijectors + tfd = tf.contrib.distributions - reverse = bs.Permute(permutation=[2, 1, 0]) + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) reverse.forward([-1., 0., 1.]) # ==> [1., 0., -1] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py index 93682639aa..55eca06312 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py @@ -36,70 +36,77 @@ __all__ = [ ] +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + +def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + class Reshape(bijector_lib.Bijector): """Reshapes the `event_shape` of a `Tensor`. The semantics generally follow that of `tf.reshape()`, with a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). Example usage: ```python - bs = tf.contrib.distributions.bijectors + tfd = tf.contrib.distributions - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] - reverse.forward_log_det_jacobian(any_value) + r.forward_log_det_jacobian(any_value) # ==> 0. - reverse.inverse_log_det_jacobian(any_value) + r.inverse_log_det_jacobian(any_value) # ==> 0. ``` """ - def __init__(self, event_shape_out, event_shape_in, + def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. Args: event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str`, name given to ops managed by this object. Raises: TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. """ with ops.name_scope(name, "reshape", values=[event_shape_out, event_shape_in]): @@ -111,105 +118,74 @@ class Reshape(bijector_lib.Bijector): name="event_shape_in", preferred_dtype=dtypes.int32) - # check that input shapes are positive integers assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) self._assertions = assertions self._event_shape_in = event_shape_in self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) super(Reshape, self).__init__(is_constant_jacobian=True, validate_args=validate_args, name=name or "reshape") - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) + shape.op.name, shape.dtype.name)) - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) + assertions = [] - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) + "`{}` elements must be either positive integers or `-1`" + "(given {})." + .format(shape.op.name, es)) elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." + .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) return assertions def _reshape_helper(self, x, event_shape_in, event_shape_out): """Reshape only the event_shape of an input `Tensor`.""" - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. - if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) + # Ensure x.event_shape is compatible with event_shape_in. if (event_ndims_in_ is not None and x_ndims_ is not None and x.shape.with_rank_at_least(event_ndims_in_)[ @@ -223,13 +199,35 @@ class Reshape(bijector_lib.Bijector): event_shape_in_ = tensor_util.constant_value(event_shape_in) if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # imcompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below. + x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". format(x_event_shape_, event_shape_in_)) elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input. We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, + x_event_shape_specified, event_shape_in_specified, message="Input `event_shape` does not match `event_shape_in`.")) if assertions: @@ -243,8 +241,19 @@ class Reshape(bijector_lib.Bijector): sample_and_batch_shape = sample_and_batch_shape[ :(ndims - math_ops.abs(event_ndims_in))] - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). + new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) return array_ops.reshape(x, new_shape) @@ -269,29 +278,37 @@ class Reshape(bijector_lib.Bijector): return constant_op.constant(0., dtype=x.dtype) def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 8d59c1abfb..6f5d724a2a 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -43,16 +43,17 @@ class Cauchy(distribution.Distribution): The probability density function (pdf) is, ```none - pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale ``` where `loc` is the location, and `scale` is the scale. The Cauchy distribution is a member of the [location-scale family]( https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, ```none X ~ Cauchy(loc=0, scale=1) - Y ~ Cauchy(loc=loc, scale=scale) Y = loc + scale * X ``` @@ -61,14 +62,16 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Cauchy distribution. - dist = Cauchy(loc=0., scale=3.) + dist = tfd.Cauchy(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Cauchy distributions. - dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,18 +79,17 @@ class Cauchy(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - - Arguments are broadcast when possible. - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Cauchy distributions. # Both have median 1, but different scales. - dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. - dist.prob(3.0) + dist.prob(3.) ``` + """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd..8049522e9f 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.Deterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] ``` diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b..d0efaefb8e 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0a..cbce005013 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. @@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a3..ee4d86867d 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a4882..473677f8d9 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d91..f2d492f548 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f25..5448918a50 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d99..e862552880 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc23072..413e88f03a 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069..8e69dadfb4 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -73,14 +73,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +100,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f60..a739979289 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f5..6c7dc4ca7a 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index e1118ed431..2701c36fb5 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -107,10 +107,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., quadrature_grid_and_probs=( diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a..c4b8f055b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. + Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08..904724af42 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -188,8 +188,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -197,20 +196,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a..526fe2d39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b3..9d5fd9ac41 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a..8dd983b750 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4ca..ec485c95c1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a871070..e1ccf11645 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c..8c67647a61 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index bf2e883bc5..55d768044b 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -232,6 +232,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 0388aaa849..e3c13cbd2e 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -451,8 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. + + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. + + Returns: + A list of tensors. + """ + layer_losses = [] + for layer in self.layers: + layer_losses.extend(layer.losses) + return layer_losses + # TODO(allenl): Support other Layer methods needed for graph mode, such as for - # losses and updates + # updates class Sequential(Network): diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index e7835a63e6..3eb4f5f8b3 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test @@ -45,6 +46,22 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): @@ -485,6 +502,18 @@ class NetworkTest(test.TestCase): checked_ops=checked_ops) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) net = MyNetwork(name="foo") diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 8395e2db5e..706a174efb 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -93,6 +93,7 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_pip", + "notap", # b/62204861 "notsan", ], deps = [ @@ -346,7 +347,7 @@ py_library( cuda_py_test( name = "replicate_model_fn_test", - size = "small", + size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "//tensorflow/python/estimator", diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index d9c83aa865..f5154231da 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -42,10 +42,49 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging +from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import training_util -def replicate_model_fn(model_fn, optimizer_fn, devices=None): +class Mode(object): + """Modes for variables replication used for forcing a particular mode. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + AUTO = 0 + """Use internal heuristics for choosing the best Mode value. + + This mode is supposed to be the most appropriate in most cases given what + is known about the system. + """ + # TODO(isaprykin): Query system configuration to choose modes other than + # `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often appropriate. + + SHARED_LOCAL_PARAMETER_SERVER = 2 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 3 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. + """ + + # TODO(isaprykin): Implement `REPLICATED_ALL_REDUCE`. + REPLICATED_ALL_REDUCE = 3 + """Variables are mirrored on all devices.""" + + +def replicate_model_fn(model_fn, optimizer_fn, devices=None, mode=Mode.AUTO): """Replicate `Estimator.model_fn` over GPUs within a single host. The given `model_fn` specifies a single forward pass of a model. To replicate @@ -58,14 +97,11 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): optimizer. If `devices` are `None`, then all available GPUs are going to be used for - replication. If no GPUs are available, then the model is going to be - placed on the CPU. + replication: `devices=[<all available GPUs>]`. If no GPUs are available, + then the model is going to be placed on the CPU: `devices=['/device:CPU:0']`. - Two modes of local replication over available GPUs are supported: - 1) If exactly 1 GPU is detected, then variables and operations are placed - onto GPU. - 2) If more than 1 GPU is detected, then variables are going to be placed on - the CPU. Replicas of operations are placed on each individual GPU. + Varibles are placed on to `devices` according to the given `mode`. Operations + are going for each tower are going to be copied on each device. Here is an example of how one might use their `model_fn` to run over GPUs: ```python @@ -127,6 +163,8 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): argument can be used to replice only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. + mode: An optional argument that specifies the replication method used for + distributing variables across devices. Returns: A replicated version of the supplied `model_fn`. Returned function that @@ -137,16 +175,21 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): devices = _get_local_devices('GPU') or _get_local_devices('CPU') is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] - local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + consolidation_device = '/{}:0'.format('GPU' + if is_a_single_gpu_case else 'CPU') + + ps_devices = [consolidation_device] + if mode == Mode.SHARED_ROUND_ROBIN: + ps_devices = devices - tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' - 'server device is going to be {}.'.format( - devices, local_ps_device)) + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( - features, labels, len(devices), device=local_ps_device) + features, labels, len(devices), device=consolidation_device) tower_specs = _get_loss_towers( model_fn=model_fn, mode=mode, @@ -155,17 +198,17 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): params=params, config=config, devices=devices, - local_ps_device=local_ps_device) + local_ps_devices=ps_devices) if mode == model_fn_lib.ModeKeys.TRAIN: train_op = _minimize_towers(tower_specs, _call_optimizer_fn(optimizer_fn, params)) return _train_spec( - tower_specs, train_op, aggregation_device=local_ps_device) + tower_specs, train_op, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.EVAL: - return _eval_spec(tower_specs, aggregation_device=local_ps_device) + return _eval_spec(tower_specs, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.PREDICT: - return _predict_spec(tower_specs, aggregation_device=local_ps_device) + return _predict_spec(tower_specs, aggregation_device=consolidation_device) return replicated_model_fn @@ -222,7 +265,7 @@ def _get_loss_towers(model_fn, params, config, devices, - local_ps_device, + local_ps_devices, name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): """Replicate the loss computation across devices.""" tower_specs = [] @@ -234,15 +277,22 @@ def _get_loss_towers(model_fn, if 'config' in model_fn_args: optional_params['config'] = copy.deepcopy(config) + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + # pylint: enable=protected-access + for i, device in enumerate(devices): is_the_first_tower = (i == 0) device_setter = _local_device_setter( - worker_device=device, ps_device=local_ps_device) + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) - # We would like to preserve the names of the variables and ops that a user - # might be relying on. Names with prefix are going to resolve to variables - # and ops of the first tower. + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. name_scope = name_scope_pattern if is_the_first_tower: name_scope = '' @@ -263,7 +313,7 @@ def _get_loss_towers(model_fn, return tower_specs -def _local_device_setter(ps_device, worker_device): +def _local_device_setter(worker_device, ps_devices, ps_strategy): """A device setter that puts distributes Var/Ops to PS/workers.""" ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] @@ -273,7 +323,7 @@ def _local_device_setter(ps_device, worker_device): node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def if node_def.op in ps_ops: ps_device_spec = framework_device.DeviceSpec.from_string( - '{}'.format(ps_device)) + '{}'.format(ps_devices[ps_strategy(op)])) ps_device_spec.merge_from(current_device) return ps_device_spec.to_string() diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index ffe69f89b4..662021853d 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -49,15 +49,29 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +# TODO(isaprykin): Parametrize all the tests on replicate_model_fn.Mode when +# it's supported. class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow(self): + def test_complete_flow_with_mode_auto(self): + return self._complete_flow_with_mode(replicate_model_fn.Mode.AUTO) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn.Mode.SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn.Mode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -109,7 +123,8 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): model_fn=replicate_model_fn.replicate_model_fn( estimator.model_fn, optimizer_fn, - devices=['/gpu:0', '/gpu:1', '/gpu:2']), + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + mode=mode), model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -359,7 +374,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], name_scope_pattern='test_tower_{}') session.run(variables.global_variables_initializer()) @@ -382,6 +397,54 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(0.25, session.run(c)) + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + class SplitBatchTest(test_util.TensorFlowTestCase): @@ -604,7 +667,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], ) session.run(variables.global_variables_initializer()) @@ -850,25 +913,66 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase): class LocalDeviceSetterTest(test_util.TensorFlowTestCase): def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + local_device_setter = replicate_model_fn._local_device_setter( - ps_device='/device:GPU:3', worker_device='/device:GPU:2') + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') with ops_lib.device(local_device_setter): - c = variables.Variable(0.01) + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) self.assertEqual('/device:GPU:3', c.device) - cc = variables.Variable(0.02) - self.assertEqual('/device:GPU:3', cc.device) + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) - ccc = variables.Variable(0.03) - self.assertEqual('/device:GPU:3', ccc.device) + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) c_op = array_ops.concat(c, axis=0) self.assertEqual('/device:GPU:2', c_op.device) - cc_op = array_ops.concat(cc, axis=0) - self.assertEqual('/device:GPU:2', cc_op.device) - class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index dc5a04a0b1..eccce99071 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -155,7 +155,10 @@ tf_py_test( data = [ ":test_data", ], - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) py_library( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 871dff7bbe..daba965a98 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,6 +26,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py index 4d1fac4ef8..b43b6b8919 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import os.path -import six +import six # pylint: disable=unused-import from tensorflow.contrib import ffmpeg -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -32,7 +30,8 @@ from tensorflow.python.platform import test class DecodeVideoOpTest(test.TestCase): - def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index): + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): """Loads an video file and validates the output tensor. Args: @@ -40,6 +39,8 @@ class DecodeVideoOpTest(test.TestCase): width: The width of the video. height: The height of the video. frames: The frames of the video. + bmp_filename: The filename for the bmp file. + index: Index location inside the video. """ with self.test_session(): path = os.path.join(resource_loader.get_data_files_path(), 'testdata', @@ -48,7 +49,7 @@ class DecodeVideoOpTest(test.TestCase): contents = f.read() bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', - bmp_filename) + bmp_filename) with open(bmp_path, 'rb') as f: bmp_contents = f.read() @@ -58,7 +59,7 @@ class DecodeVideoOpTest(test.TestCase): video_op = ffmpeg.decode_video(contents) video = video_op.eval() self.assertEqual(video.shape, (frames, height, width, 3)) - self.assertAllEqual(video[index,:,:,:], image) + self.assertAllEqual(video[index, :, :, :], image) def testMp4(self): self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 201774e1d0..1245f515fe 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -220,7 +220,8 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, Status ReadInfoFile(const string& filename, uint32* width, uint32* height, uint32* frames) { string data; - ReadFileToString(Env::Default(), filename, &data); + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; bool in_output = false; bool in_mapping = false; uint32 frames_value = 0; @@ -377,7 +378,7 @@ Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data, open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); if (fd < 0) { const int error = errno; - LOG(ERROR) << "FFmpeg stderr file coule not be created: " + LOG(ERROR) << "FFmpeg stderr file could not be created: " << strerror(error); ::_exit(error); } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 39e7e90ccc..36fc71794b 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -23,6 +23,7 @@ #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 78ead471d2..08b5a6ea48 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 6d5cde5c9e..a18ff2320d 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -150,5 +150,5 @@ def get_placeholders(graph): # The return value (a Tensor) of placeholder() is the # first output of this operation in fact. operations = graph.get_operations() - result = [i.outputs[0] for i in operations if i.type == 'Placeholder'] + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 0722fafc13..b8a6d109e1 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -90,8 +90,9 @@ class GetPlaceholdersTest(test.TestCase): with ops.Graph().as_default() as g: placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] results = graph_util.get_placeholders(g) - self.assertEqual(sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access - sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access if __name__ == '__main__': diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab..5fec69ea43 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. + {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af1158..fa7a3c03aa 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c..6a56237f67 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 2a97a79070..14ac529665 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -173,6 +173,9 @@ def copy_op_handler(info, op, copy_shape=True): if op._original_op: op_._original_op = op._original_op + # Add op to the graph + info.graph_._add_op(op_) + return op_, op_.outputs diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index fbc192f1dc..6c1dd0ae40 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -580,6 +580,9 @@ class ConvDiagonalFactor(DiagonalFactor): # the target entry of _outputs_grads changes with idx.) with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): filter_height, filter_width, _, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. patches = array_ops.extract_image_patches( inputs, ksizes=[1, filter_height, filter_width, 1], @@ -739,6 +742,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): # TODO(jamesmartens): factor this patches stuff out into a utility function with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): filter_height, filter_width, in_channels, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. patches = array_ops.extract_image_patches( self._inputs, ksizes=[1, filter_height, filter_width, 1], diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85..092d418c3f 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -521,7 +521,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +539,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 6cd586a5f0..6569b7ec9a 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2561,7 +2561,10 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index a05e464a26..ae64b75d93 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -3332,11 +3332,18 @@ class SeparableConv2dTest(test.TestCase): batch, height, width = 4, 10, 12 kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) - output = layers_lib.separable_conv2d(images, num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim], - depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride]) + output = layers_lib.separable_conv2d( + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 94920db574..26bbcab307 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -461,6 +461,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index e0f2ef768b..cbc96e6edd 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -1,4 +1,19 @@ #!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + set -e make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 571d857be7..7fce1ba346 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h index 75b1f1da38..94046d9728 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -14,8 +14,8 @@ #import <UIKit/UIKit.h> -@interface AppDelegate : UIResponder <UIApplicationDelegate> +@interface AppDelegate : UIResponder<UIApplicationDelegate> -@property (strong, nonatomic) UIWindow *window; +@property(strong, nonatomic) UIWindow *window; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm index 1e808eb976..d1215fa0bf 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -22,8 +22,7 @@ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; bar.selectedIndex = 0; self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; self.window.rootViewController = bar; @@ -31,14 +30,19 @@ return YES; } -- (void)applicationWillResignActive:(UIApplication *)application {} +- (void)applicationWillResignActive:(UIApplication *)application { +} -- (void)applicationDidEnterBackground:(UIApplication *)application {} +- (void)applicationDidEnterBackground:(UIApplication *)application { +} -- (void)applicationWillEnterForeground:(UIApplication *)application {} +- (void)applicationWillEnterForeground:(UIApplication *)application { +} -- (void)applicationDidBecomeActive:(UIApplication *)application {} +- (void)applicationDidBecomeActive:(UIApplication *)application { +} -- (void)applicationWillTerminate:(UIApplication *)application {} +- (void)applicationWillTerminate:(UIApplication *)application { +} @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h index 4e1a83ccf5..a4b358b4eb 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -18,7 +18,7 @@ - (IBAction)getUrl:(id)sender; -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 965d830105..0dafb1f61e 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -14,10 +14,10 @@ #import "RunModelViewController.h" -#include <fstream> -#include <iostream> #include <pthread.h> #include <unistd.h> +#include <fstream> +#include <iostream> #include <queue> #include <sstream> #include <string> @@ -30,7 +30,11 @@ #include "ios_image_load.h" #define LOG(x) std::cerr -#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } NSString* RunInferenceOnImage(); @@ -49,15 +53,12 @@ NSString* RunInferenceOnImage(); // Returns the top N confidence values over threshold in the provided vector, // sorted by confidence in descending order. -static void GetTopN( - const float* prediction, - const int prediction_size, - const int num_results, const float threshold, - std::vector<std::pair<float, int> >* top_results) { +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector<std::pair<float, int> >* top_results) { // Will contain top N results in ascending order. - std::priority_queue<std::pair<float, int>, - std::vector<std::pair<float, int> >, - std::greater<std::pair<float, int> > > top_result_pq; + std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int> >, + std::greater<std::pair<float, int> > > + top_result_pq; const long count = prediction_size; for (int i = 0; i < count; ++i) { @@ -88,8 +89,8 @@ static void GetTopN( NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; } return file_path; } @@ -102,7 +103,8 @@ NSString* RunInferenceOnImage() { NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); - std::unique_ptr<tflite::FlatBufferModel> model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + std::unique_ptr<tflite::FlatBufferModel> model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); if (!model) { LOG(FATAL) << "Failed to mmap model " << graph; } @@ -143,7 +145,7 @@ NSString* RunInferenceOnImage() { std::ifstream t; t.open([labels_path UTF8String]); std::string line; - while(t){ + while (t) { std::getline(t, line); label_strings.push_back(line); } @@ -154,7 +156,8 @@ NSString* RunInferenceOnImage() { int image_width; int image_height; int image_channels; - std::vector<uint8_t> image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + std::vector<uint8_t> image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); const int wanted_width = 224; const int wanted_height = 224; const int wanted_channels = 3; @@ -212,8 +215,7 @@ NSString* RunInferenceOnImage() { std::string predictions = ss.str(); NSString* result = @""; - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + return result; } diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h index 7287d0d63d..98934ce41d 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -17,9 +17,7 @@ #include <vector> -std::vector<uint8_t> LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); +std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); #endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm index 789522d2a9..cb0fe1a765 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -14,17 +14,16 @@ #include "ios_image_load.h" -#include <stdlib.h> -#include <string.h> #include <assert.h> #include <stdio.h> +#include <stdlib.h> +#include <string.h> #import <CoreImage/CoreImage.h> #import <ImageIO/ImageIO.h> -std::vector<uint8_t> LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { +std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { FILE* file_handle = fopen(file_name, "rb"); fseek(file_handle, 0, SEEK_END); const size_t bytes_in_file = ftell(file_handle); @@ -32,11 +31,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name, std::vector<uint8_t> file_data(bytes_in_file); fread(file_data.data(), 1, bytes_in_file, file_handle); fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); const char* suffix = strrchr(file_name, '.'); if (!suffix || suffix == file_name) { @@ -44,12 +42,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name, } CGImageRef image; if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); } else { CFRelease(image_provider); CFRelease(file_data_ref); @@ -68,9 +64,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name, const int bytes_in_image = (bytes_per_row * height); std::vector<uint8_t> result(bytes_in_image); const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(color_space); CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); CGContextRelease(context); diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm index d70550a730..05cb55ddd7 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/main.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -14,7 +14,7 @@ #import <UIKit/UIKit.h> -int main(int argc, char * argv[]) { +int main(int argc, char *argv[]) { @autoreleasepool { NSString *delegateClassName = @"AppDelegate"; return UIApplicationMain(argc, argv, nil, delegateClassName); diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index bcff7ed988..345ed26212 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -1,47 +1,31 @@ -# Settings for iOS. -ifeq ($(TARGET), IOS) - BUILD_FOR_IOS_SIMULATOR := false - ifeq ($(IOS_ARCH), x86_64) - BUILD_FOR_IOS_SIMULATOR := true - endif - ifeq ($(IOS_ARCH), i386) - BUILD_FOR_IOS_SIMULATOR := true - endif - ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) - IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ - --show-sdk-platform-path) - IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \ - --show-sdk-path) - else - IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path) - IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path) - endif - IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) - MIN_SDK_VERSION := 9.0 - # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. - IOS_ARCH := x86_64 - CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ - -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ - -fembed-bitcode \ - -Wno-c++11-narrowing \ - -mno-thumb \ - -fno-exceptions \ - -isysroot \ - ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ - -O3 - CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ - -fembed-bitcode \ - -mno-thumb \ - -isysroot \ - ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ - -O3 - LDFLAGS := -fembed-bitcode \ - -miphoneos-version-min=${MIN_SDK_VERSION} \ - -arch $(IOS_ARCH) - OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ - LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ - BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ - DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ -endif +#Settings for iOS. +ifeq($(TARGET), IOS) BUILD_FOR_IOS_SIMULATOR + : = false ifeq($(IOS_ARCH), x86_64) BUILD_FOR_IOS_SIMULATOR + : = true endif ifeq($(IOS_ARCH), i386) BUILD_FOR_IOS_SIMULATOR + : = true endif ifeq($(BUILD_FOR_IOS_SIMULATOR), true) IPHONEOS_PLATFORM + : = $(shell xcrun-- sdk iphonesimulator-- show - sdk - platform - + path) IPHONEOS_SYSROOT + : = $(shell xcrun-- sdk iphonesimulator-- show - sdk - + path) else IPHONEOS_PLATFORM + : = $(shell xcrun-- sdk iphoneos-- show - sdk - platform - + path) IPHONEOS_SYSROOT + : = $(shell xcrun-- sdk iphoneos-- show - sdk - path) endif IOS_SDK_VERSION + : = $(shell xcrun-- sdk iphoneos-- show - sdk - version) MIN_SDK_VERSION + : = 9.0 +#Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. + IOS_ARCH + : = x86_64 CXXFLAGS + += -miphoneos - version + - min = $(MIN_SDK_VERSION) - DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK + - fembed - bitcode - Wno - c++ 11 - narrowing - mno - thumb + - fno - exceptions + - isysroot ${IPHONEOS_SYSROOT} - arch $(IOS_ARCH) - O3 CCFLAGS + += -miphoneos - version + - min = $(MIN_SDK_VERSION) - fembed - bitcode - mno - thumb + - isysroot ${IPHONEOS_SYSROOT} - arch $(IOS_ARCH) - + O3 LDFLAGS + : = -fembed - bitcode - miphoneos - version + - min = ${MIN_SDK_VERSION} - arch $(IOS_ARCH) OBJDIR + : = $(OBJDIR) ios_$(IOS_ARCH) / LIBDIR + : = $(LIBDIR) ios_$(IOS_ARCH) / BINDIR + : = $(BINDIR) ios_$(IOS_ARCH) / DEPDIR : = $(DEPDIR) ios_$(IOS_ARCH) / endif diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 71b633c577..5d13a798e2 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -8,7 +8,12 @@ It's easiest with Android Studio. - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) - You also need to install the Android Support Repository, available through Android Studio under `Android SDK Manager -> SDK Tools -> Android Support Repository`. @@ -19,7 +24,8 @@ - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that you have installed. - By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle`. + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). 2. Build the app with Bazel. The demo needs C++11: diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc index 30d89a1354..bf95b313f3 100644 --- a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc +++ b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Unit test for speech TERSE AM model using TFLite Ops. +// Unit test for speech ASR AM model using TFLite Ops. #include <string.h> @@ -45,10 +45,10 @@ constexpr int kLstmLayer5OutputStateTensor = 103; constexpr int kLstmLayer5CellStateTensor = 104; constexpr int kModelOutputTensor = 109; -TEST(SpeechTerseAm, RandomIOTest) { +TEST(SpeechAsrAm, RandomIOTest) { // Read the model. string tflite_file_path = - file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite"); + file::JoinPath(TestDataPath(), "speech_asr_am_model.tflite"); auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); CHECK(model) << "Failed to mmap model " << tflite_file_path; @@ -62,13 +62,13 @@ TEST(SpeechTerseAm, RandomIOTest) { // Load the input frames. Frames input_frames; const string input_file_path = - file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv"); + file::JoinPath(TestDataPath(), "speech_asr_am_model_in.csv"); ReadFrames(input_file_path, &input_frames); // Load the golden output results. Frames output_frames; const string output_file_path = - file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv"); + file::JoinPath(TestDataPath(), "speech_asr_am_model_out.csv"); ReadFrames(output_file_path, &output_frames); const int speech_batch_size = diff --git a/tensorflow/contrib/lite/models/speech_terse_lm_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc index 04c54ffb22..53f2b66da4 100644 --- a/tensorflow/contrib/lite/models/speech_terse_lm_model_test.cc +++ b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc @@ -59,10 +59,10 @@ static void ClearLstmStates(Interpreter* interpreter) { interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); } -TEST(SpeechTerseLm, EndToEndTest) { +TEST(SpeechAsrLm, EndToEndTest) { // Read the model. string tflite_file_path = - file::JoinPath(TestDataPath(), "speech_terse_lm_model.tflite"); + file::JoinPath(TestDataPath(), "speech_asr_lm_model.tflite"); auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); CHECK(model) << "Failed to mmap model " << tflite_file_path; @@ -76,13 +76,13 @@ TEST(SpeechTerseLm, EndToEndTest) { // Load the input frames. Frames input_frames; const string input_file_path = - file::JoinPath(TestDataPath(), "speech_terse_lm_model_in.csv"); + file::JoinPath(TestDataPath(), "speech_asr_lm_model_in.csv"); ReadFrames(input_file_path, &input_frames); // Load the golden output results. Frames output_frames; const string output_file_path = - file::JoinPath(TestDataPath(), "speech_terse_lm_model_out.csv"); + file::JoinPath(TestDataPath(), "speech_asr_lm_model_out.csv"); ReadFrames(output_file_path, &output_frames); CHECK_EQ(interpreter->tensor(kModelInput1Tensor)->dims->size, 1); diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md index c9630c00db..46b24248f0 100644 --- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md +++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md @@ -86,25 +86,34 @@ same input. ### Models: -[Speech hotword model (Svdf rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite) +[Speech hotword model (Svdf +rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite) -[Speech hotword model (Svdf rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite) +[Speech hotword model (Svdf +rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite) -[Speaker-id model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite) +[Speaker-id +model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite) -[TTS model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite) +[TTS +model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite) -[ASR AM model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite) +[ASR AM +model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite) ### Test benches -[Speech hotword model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc) +[Speech hotword model +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc) -[Speaker-id model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc) +[Speaker-id model +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc) -[TTS model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc) +[TTS model +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc) -[ASR AM model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc) +[ASR AM model +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc) ## Android Support The models have been tested on Android phones, using the following tests: diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index b78e958e7f..bdb5e01538 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -1454,9 +1454,9 @@ inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) { * {@link ANeuralNetworksExecution_setOutputFromMemory} and * {@link ANeuralNetworksExecution_setOperandValue}. * - * To build a model that can accommodate inputs of various sizes, as you may want - * to do for a CNN, set the size of the dimensions that will vary at run time to - * 0. If you do so, provide the full dimensions when calling + * To build a model that can accommodate inputs of various sizes, as you may + * want to do for a CNN, set the size of the dimensions that will vary at run + * time to 0. If you do so, provide the full dimensions when calling * {@link ANeuralNetworksExecution_setInput} or {@link * ANeuralNetworksExecution_setInputFromMemory}. * diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 0fd70f842b..982ea90f2b 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -50,7 +50,7 @@ GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT # to protect against crashes. However, it breaks some dependent targets because # it forces us to depend on an external py_binary. The experimental API doesn't # have that drawback. -EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True +EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False # Find the toco_from_protos binary using the resource loader if using from # bazel, otherwise we are in a pip where console_scripts already has diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc index f80949b23e..6ae3ab5729 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark_model.cc @@ -31,7 +31,12 @@ void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); #endif #define LOG(x) std::cerr -#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } + +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } namespace tensorflow { namespace benchmark_tflite_model { diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h index 8206a5481d..be60cf476d 100644 --- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h @@ -20,15 +20,14 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" // Needed to resolve unordered_set hash on older compilers. -namespace std -{ -template<> - struct hash<tflite::BuiltinOperator> { - size_t operator()(const tflite::BuiltinOperator &op) const { - return std::hash<int>()(op); - } - }; -} +namespace std { +template <> +struct hash<tflite::BuiltinOperator> { + size_t operator()(const tflite::BuiltinOperator& op) const { + return std::hash<int>()(op); + } +}; +} // namespace std namespace tflite { diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index 31a35b0d53..913935b382 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -258,9 +258,37 @@ NcclManager::Communicator* NcclManager::GetCommunicator( devices[i] = collective->participants[i]->gpu_device_id; } + int device_count = num_devices; +#if NCCL_MAJOR >= 2 + // NCCL2 prevents InitAll for more communicators than devices (but doesn't + // check that device ids are unique). Work around it by initializing each + // rank individually. + cudaGetDeviceCount(&device_count); +#endif std::vector<ncclComm_t> nccl_comms(num_devices); - auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data()); - CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result); + if (num_devices <= device_count) { + auto result = + ncclCommInitAll(nccl_comms.data(), num_devices, devices.data()); + CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result); + } else { + int savedDevice = 0; + CHECK_EQ(cudaGetDevice(&savedDevice), cudaSuccess); + ncclUniqueId commId; + ncclGetUniqueId(&commId); +#if NCCL_MAJOR >= 2 + CHECK_EQ(ncclGroupStart(), ncclSuccess); +#endif + for (int rank = 0; rank < num_devices; ++rank) { + cudaSetDevice(devices[rank]); + auto result = + ncclCommInitRank(nccl_comms.data() + rank, num_devices, commId, rank); + CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result); + } +#if NCCL_MAJOR >= 2 + CHECK_EQ(ncclGroupEnd(), ncclSuccess); +#endif + cudaSetDevice(savedDevice); + } for (int rank = 0; rank < num_devices; ++rank) { members[rank].nccl_comm = nccl_comms[rank]; } diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 505c4b0d71..abafe4b407 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -30,6 +30,8 @@ namespace tensorflow { static std::vector<BaseGPUDevice*> GetGPUDevices() { std::vector<Device*> devices; SessionOptions session_options; + session_options.config.mutable_gpu_options() + ->set_per_process_gpu_memory_fraction(0.1); session_options.env = Env::Default(); Status s = DeviceFactory::GetFactory(DEVICE_GPU) ->AddDevices(session_options, "", &devices); diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index f783179f61..9e6af5232f 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -31,6 +31,7 @@ from tensorflow.contrib.summary.summary_ops import audio from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir +from tensorflow.contrib.summary.summary_ops import flush from tensorflow.contrib.summary.summary_ops import generic from tensorflow.contrib.summary.summary_ops import graph from tensorflow.contrib.summary.summary_ops import histogram diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 8e37987cb7..de6f2cd79f 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -516,6 +516,27 @@ def import_event(tensor, name=None): context.context().summary_writer_resource, tensor, name=name) +def flush(writer=None, name=None): + """Forces summary writer to send any buffered data to storage. + + This operation blocks until that finishes. + + Args: + writer: The @{tf.contrib.summary.SummaryWriter} resource to flush. + The thread default will be used if this parameter is None. + Otherwise a @{tf.no_op} is returned. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation}. + """ + if writer is None: + writer = context.context().summary_writer_resource + if writer is None: + return control_flow_ops.no_op() + return gen_summary_ops.flush_summary_writer(writer, name=name) + + def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index d20300c858..54433deb28 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -108,6 +108,33 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testMaxQueue(self): + logs = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logs, max_queue=2, flush_millis=999999, + name='lol').as_default(), summary_ops.always_record_summaries(): + get_total = lambda: len(summary_test_util.events_from_logdir(logs)) + # Note: First tf.Event is always file_version. + self.assertEqual(1, get_total()) + summary_ops.scalar('scalar', 2.0, step=1) + self.assertEqual(1, get_total()) + summary_ops.scalar('scalar', 2.0, step=2) + self.assertEqual(3, get_total()) + + def testFlush(self): + logs = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logs, max_queue=999999, flush_millis=999999, + name='lol').as_default(), summary_ops.always_record_summaries(): + get_total = lambda: len(summary_test_util.events_from_logdir(logs)) + # Note: First tf.Event is always file_version. + self.assertEqual(1, get_total()) + summary_ops.scalar('scalar', 2.0, step=1) + summary_ops.scalar('scalar', 2.0, step=2) + self.assertEqual(1, get_total()) + summary_ops.flush() + self.assertEqual(3, get_total()) + class DbTest(summary_test_util.SummaryDbTest): diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index 94767c8df2..915820e05b 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -83,7 +83,7 @@ def events_from_logdir(logdir): """ assert gfile.Exists(logdir) files = gfile.ListDirectory(logdir) - assert len(files) == 1, "Found not exactly one file in logdir: %s" % files + assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files return events_from_file(os.path.join(logdir, files[0])) diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index f542d94139..a34c7f91f2 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -32,21 +32,6 @@ cc_library( ) py_library( - name = "tpu_test_util", - srcs = ["python/tpu/test_util.py"], - srcs_version = "PY2AND3", - deps = [ - ":tpu_lib", - ":tpu_py", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:session", - "//tensorflow/python:variables", - ], -) - -py_library( name = "tpu_estimator", srcs = [ "python/tpu/tpu_config.py", diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index cbbd19800e..d389050e67 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -22,7 +22,7 @@ namespace tensorflow { REGISTER_OP("CrossReplicaSum") .Input("input: T") .Output("output: T") - .Attr("T: {float}") + .Attr("T: {bfloat16, float}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py deleted file mode 100644 index a5d4ff9722..0000000000 --- a/tensorflow/contrib/tpu/python/tpu/test_util.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# =================================================================== -"""Utilities to ease testing on TPU devices.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import pickle -import tempfile - -import numpy as np - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as tf_session -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver as tf_saver - - -def has_tpu(): - """Check if a TPU device is available. - - Device enumeration via `device_lib` currently fails for TPU systems. - (http://b/68333779). To work around this, we determine the existence of a - TPU by a successful call to `initialize_system`. - - Returns: - boolean, True if a TPU device is available, otherwise False. - """ - - def _check(): - with tf_session.Session() as sess: - sess.run(tpu.initialize_system()) - sess.run(tpu.shutdown_system()) - - try: - _check() - return True - except errors.OpError as _: - return False - - -def _available_devices(): - devices = ["cpu"] - if not test_util.gpu_device_name(): - devices.append("gpu") - - if has_tpu(): - devices.append("tpu") - - return tuple(devices) - - -def copy_dir(src, tgt): - """Copy src to tgt.""" - gfile.MakeDirs(tgt) - seen_dirs = set() - for dirname, _, files in gfile.Walk(src): - for f in files: - src_f = os.path.join(dirname, f) - tgt_f = src_f.replace(src, tgt) - tgt_d = os.path.dirname(tgt_f) - if tgt_d not in seen_dirs: - gfile.MkDir(tgt_d) - seen_dirs.add(tgt_d) - gfile.Copy(src_f, tgt_f, overwrite=True) - - -def compare_model(model_fn, - input_fn, - params, - master="local", - temp_dir=None, - num_shards=2, - tolerance=1e-4): - """Compare the results of running `model_fn` on the TPU and CPU.""" - if not temp_dir: - temp_dir = tempfile.mkdtemp() - - cpu_model_dir = "%s/cpu-model" % temp_dir - tpu_model_dir = "%s/tpu-model" % temp_dir - initial_model_dir = "%s/initial-model" % temp_dir - - logging.info("Checkpoints and weights will be written to %s", temp_dir) - - num_steps = 1 - - def _model_adapter(features, labels, mode, params): - """Run users model function with random seeds fixed to known values.""" - random_seed.set_random_seed(0) - np.random.seed(0) - return model_fn(features, labels, mode, params) - - def _input_adapter(params): - random_seed.set_random_seed(0) - np.random.seed(0) - return input_fn(params) - - def _make_run_config(model_dir): - return tpu_config.RunConfig( - master=master, - model_dir=model_dir, - save_checkpoints_secs=10000, - session_config=config_pb2.ConfigProto( - allow_soft_placement=True, log_device_placement=False), - tpu_config=tpu_config.TPUConfig( - iterations_per_loop=num_steps, - num_shards=num_shards, - ), - ) - - def _make_estimator(use_tpu, model_dir): - return tpu_estimator.TPUEstimator( - model_fn=_model_adapter, - use_tpu=use_tpu, - config=_make_run_config(model_dir), - train_batch_size=num_shards, - params=dict(params, use_tpu=use_tpu), - ) - - def _extract_weights(checkpoint): - """Extract model weights from the given checkpoint file.""" - weights = {} - graph = ops.Graph() - with graph.as_default(): - features, labels = _input_adapter(dict(params, batch_size=num_shards)) - model_fn( - features, labels, - params=dict(params, use_tpu=False), - mode=model_fn_lib.ModeKeys.TRAIN) - saver = tf_saver.Saver() - with tf_session.Session(graph=graph) as sess: - saver.restore(sess, checkpoint) - all_vars = [] - all_vars.extend(graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - all_vars.extend(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - all_vars.extend(graph.get_collection(ops.GraphKeys.MODEL_VARIABLES)) - - for var in all_vars: - weights[var.name] = sess.run(var) - return weights - - def _run_step(use_tpu, model_dir): - """Create an estimator and run a single step on the given device.""" - tf_session.Session.reset(target=master) - - logging.info("Running step. TPU=%d. model_dir=%s", use_tpu, model_dir) - est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir) - est.train(input_fn=_input_adapter, steps=num_steps) - weights = _extract_weights(est.latest_checkpoint()) - with gfile.Open(os.path.join(temp_dir, "tpu-%d.weights" % use_tpu), - "wb") as f: - f.write(pickle.dumps(weights)) - return weights - - # initialize models to the same weights by running a single step on the CPU - _run_step(use_tpu=False, model_dir=initial_model_dir) - - copy_dir(initial_model_dir, cpu_model_dir) - copy_dir(initial_model_dir, tpu_model_dir) - - cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir) - tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir) - - bad_weights = False - for k in cpu_weights: - if k not in tpu_weights: - raise KeyError("Missing weight %s from TPU checkpoint.", k) - - if not np.allclose( - cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance): - bad_weights = True - logging.error("Weights for layer %s have diverged.", k) - - if bad_weights: - raise ValueError("Some weights have diverged. Output pickle files have " - "been written to %s for inspection." % temp_dir) - - -class TPUTestCase(test_util.TensorFlowTestCase): - """Adds helpers for testing on TPU devices to `TensorFlowTestCase`. - - Example usage: - - ``` - def model_fn(features): - return tf.reduce_sum(features * 2) - - class ModelTests(test_util.TPUTestCase): - def test_sum(self): - v = np.random.randn(10, 10).astype("float32") - self.assert_device_output(model_fn, [v], (v*2).sum(), - devices=("cpu", "tpu")) - ``` - """ - - def __init__(self, methodName="runTest"): # pylint: disable=invalid-name - super(TPUTestCase, self).__init__(methodName) - self._available_devices = _available_devices() - - def run_on_device(self, model_fn, model_inputs, device): - """Runs `model_fn` on the given device. - - Raises an exception if no such device is available. `model_fn` should - return one or more tensors as a list or tuple. - - Args: - model_fn: Function returning one or more tensors. - model_inputs: An iterable of Numpy arrays or scalars. - These will be passed as arguments to `model_fn`. - device: Device to run on. One of ("tpu", "gpu", "cpu"). - - Returns: - Output from the model function. - """ - - def _make_placeholders(): - return dict([(gen_array_ops.placeholder_with_default(v, v.shape), v) - for v in model_inputs]) - - if device == "tpu": - with self.test_session(graph=ops.Graph()) as sess: - placeholders = _make_placeholders() - tpu_computation = tpu.rewrite(model_fn, placeholders.keys()) - sess.run(tpu.initialize_system()) - sess.run(variables.global_variables_initializer()) - result = sess.run(tpu_computation, placeholders) - sess.run(tpu.shutdown_system()) - # TODO(b/36891278): supports non-flat returns lists in tpu.rewrite(). - if len(result) == 1: - return result[0] - return result - elif device == "gpu": - with self.test_session(graph=ops.Graph(), use_gpu=True) as sess: - placeholders = _make_placeholders() - sess.run(variables.global_variables_initializer()) - return sess.run(model_fn(placeholders.keys()), placeholders) - elif device == "cpu": - # TODO(power) -- will this interact poorly with cached GPU sessions? - with self.test_session(graph=ops.Graph(), use_gpu=False) as sess: - placeholders = _make_placeholders() - sess.run(variables.global_variables_initializer()) - return sess.run(model_fn(placeholders.keys()), placeholders) - - def _compare_values(self, actual_outputs, expected_outputs): - if isinstance(expected_outputs, (list, tuple)): - for a, b in zip(actual_outputs, expected_outputs): - self.assertAllCloseAccordingToType(a, b) - else: - self.assertAllCloseAccordingToType(actual_outputs, expected_outputs) - - def assert_device_output(self, - model_fn, - model_inputs, - expected_outputs, - devices=("cpu", "gpu", "tpu")): - """Run `model_fn` on the given devices. - - Results are compared via `assertAllCloseAccordingToType`. - - Args: - model_fn: Function returning one or more tensors - model_inputs: Numpy arrays or scalars passed as arguments to model_fn - expected_outputs: Numpy arrays or scalars to compare against. - devices: Set of devices to run on. If a device is not available, tests - will be skipped for that device. - """ - devices = set(devices).intersection(self._available_devices) - - for device in devices: - device_out = self.run_on_device(model_fn, model_inputs, device=device) - self._compare_values(device_out, expected_outputs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index fe17664d7f..84a4208be3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -514,6 +514,7 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): exc_info=1 ) time.sleep(120) + logging.error('Closing the failed session.') session.close() def join(self): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bd7617fa96..5bcb87d2d1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1016,7 +1016,7 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts() + if_not_android_mips_and_mips64(["-Os"]), + copts = tf_copts(android_optimization_level_override = None), linkopts = ["-lz"], tags = [ "manual", @@ -1106,8 +1106,7 @@ cc_library( cc_library( name = "android_tensorflow_lib_selective_registration", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts() + [ - "-Os", + copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], tags = [ @@ -1129,8 +1128,7 @@ cc_library( cc_library( name = "android_tensorflow_lib_selective_registration_nortti", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts() + tf_opts_nortti_if_android() + [ - "-Os", + copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], tags = [ @@ -1210,7 +1208,7 @@ cc_library( "framework/tensor_testutil.h", "util/reporter.h", ], - copts = tf_copts() + ["-Os"], + copts = tf_copts(android_optimization_level_override = None), tags = [ "manual", "notap", diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt index 6522ce976f..070d6adb97 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt @@ -26,7 +26,7 @@ END description: <<END 1-D tensor of length 4. The stride of the sliding window for each dimension of `input`. The dimension order is determined by the value of - `data_format`, see below for details. +`data_format`, see below for details. END } attr { @@ -45,6 +45,16 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors." description: <<END Given an input tensor of shape `[batch, in_height, in_width, in_channels]` diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt index 4ea3374dbb..ff2d9d71db 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt @@ -53,5 +53,15 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, in_channels, in_height, in_width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each filter +element on that dimension. The dimension order is determined by the value of +`data_format`, see above for details. Dilations in the batch and depth +dimensions must be 1. +END + } summary: "Computes the gradients of convolution with respect to the filter." } diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt index 4420073e38..2de38b4263 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt @@ -52,5 +52,15 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, in_channels, in_height, in_width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each filter +element on that dimension. The dimension order is determined by the value of +`data_format`, see above for details. Dilations in the batch and depth +dimensions must be 1. +END + } summary: "Computes the gradients of convolution with respect to the input." } diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt index 8f3cd4493c..d26564097e 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt @@ -36,6 +36,16 @@ Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 5. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } summary: "Computes a 3-D convolution given 5-D `input` and `filter` tensors." description: <<END In signal processing, cross-correlation is a measure of similarity of diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt index 6f9b917237..937c9c8ead 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt @@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 5. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } summary: "Computes the gradients of 3-D convolution with respect to the filter." } diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt index 19aba156d5..414e418dc5 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt @@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 5. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } summary: "Computes the gradients of 3-D convolution with respect to the input." } diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt index cc10ebe923..3c313f7be6 100644 --- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt @@ -23,6 +23,16 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each filter +element on that dimension. The dimension order is determined by the value of +`data_format`, see above for details. Dilations in the batch and depth +dimensions must be 1. +END + } summary: "Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors." description: <<END Given an input tensor of shape `[batch, in_height, in_width, in_channels]` diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt index 9126be2afa..e66aa3b707 100644 --- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt @@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each filter +element on that dimension. The dimension order is determined by the value of +`data_format`, see above for details. Dilations in the batch and depth +dimensions must be 1. +END + } summary: "Computes the gradients of depthwise convolution with respect to the filter." } diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt index f1d16858db..f501ad21b3 100644 --- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt @@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each filter +element on that dimension. The dimension order is determined by the value of +`data_format`, see above for details. Dilations in the batch and depth +dimensions must be 1. +END + } summary: "Computes the gradients of depthwise convolution with respect to the input." } diff --git a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt index 00e96c8a15..dfaa531cbc 100644 --- a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt @@ -14,4 +14,47 @@ The `dtype` of the serialized `SparseTensor` objects. END } summary: "Deserialize `SparseTensor` objects." + description: <<END +The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +the last dimension stores serialized `SparseTensor` objects and the other N +dimensions (N >= 0) correspond to a batch. The ranks of the original +`SparseTensor` objects must all match. When the final `SparseTensor` is +created, its rank is the rank of the incoming `SparseTensor` objects plus N; +the sparse tensors have been concatenated along new dimensions, one for each +batch. + +The output `SparseTensor` object's shape values for the original dimensions +are the max across the input `SparseTensor` objects' shape values for the +corresponding dimensions. The new dimensions match the size of the batch. + +The input `SparseTensor` objects' indices are assumed ordered in +standard lexicographic order. If this is not the case, after this +step run `SparseReorder` to restore index ordering. + +For example, if the serialized input is a `[2 x 3]` matrix representing two +original `SparseTensor` objects: + + index = [ 0] + [10] + [20] + values = [1, 2, 3] + shape = [50] + +and + + index = [ 2] + [10] + values = [4, 5] + shape = [30] + +then the final deserialized `SparseTensor` will be: + + index = [0 0] + [0 10] + [0 20] + [1 2] + [1 10] + values = [1, 2, 3, 4, 5] + shape = [2 50] +END } diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt index b19bbeab12..d18bafdce9 100644 --- a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt @@ -55,6 +55,16 @@ END The type of padding algorithm to use. END } + attr { + name: "dilations" + description: <<END +1-D tensor of length 4. The dilation factor for each dimension of +`input`. If set to k > 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } summary: "Computes a 2D convolution given quantized 4D input and filter tensors." description: <<END The inputs are quantized tensors where the lowest value represents the real diff --git a/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt new file mode 100644 index 0000000000..0466b40f85 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt @@ -0,0 +1,18 @@ +op { + graph_op_name: "RandomDataset" + in_arg { + name: "seed" + description: <<END +A scalar seed for the random number generator. If either seed or +seed2 is set to be non-zero, the random number generator is seeded +by the given seed. Otherwise, a random seed is used. +END + } + in_arg { + name: "seed2" + description: <<END +A second scalar seed to avoid seed collision. +END + } + summary: "Creates a Dataset that returns pseudorandom numbers." +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt new file mode 100644 index 0000000000..b07ee9fda9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt @@ -0,0 +1,69 @@ +op { + graph_op_name: "ResourceScatterNdUpdate" + in_arg { + name: "ref" + description: <<END +A resource handle. Must be from a VarHandleOp. +END + } + in_arg { + name: "indices" + description: <<END +A Tensor. Must be one of the following types: int32, int64. +A tensor of indices into ref. +END + } + in_arg { + name: "updates" + description: <<END +A Tensor. Must have the same type as ref. A tensor of updated +values to add to ref. +END + } + attr { + name: "use_locking" + description: <<END +An optional bool. Defaults to True. If True, the assignment will +be protected by a lock; otherwise the behavior is undefined, +but may exhibit less contention. +END + } + summary: "Applies sparse `updates` to individual values or slices within a given" + description: <<END +variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to update 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that update would look like this: + +```python + ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) +``` + +The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + +See @{tf.scatter_nd} for more details about how to make updates to +slices. +END +} diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index 6e45338751..17e6209f8e 100644 --- a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -104,6 +105,17 @@ TEST(Bfloat16Test, Conversion) { } } +TEST(Bfloat16Test, Epsilon) { + EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f))); + EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) + + bfloat16(1.0f))); +} + +TEST(Bfloat16Test, Negate) { + EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f))); + EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f))); +} + static void BM_FloatToBFloat16(int iters) { testing::StopTiming(); static const int N = 32 << 20; diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index ea66863bed..036e3473b1 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -397,6 +397,15 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); + std::vector<int32> dilations; + TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); + + if (dilations.size() != 4) { + return errors::InvalidArgument( + "Conv2D requires the dilation attribute to contain 4 values, but got: ", + dilations.size()); + } + std::vector<int32> strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); @@ -410,6 +419,8 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H'); + const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W'); DimensionHandle batch_size_dim; DimensionHandle input_depth_dim; @@ -447,12 +458,12 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); DimensionHandle output_rows, output_cols; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[0], - filter_rows_dim, stride_rows, - padding, &output_rows)); - TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[1], - filter_cols_dim, stride_cols, - padding, &output_cols)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( + c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows, + padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( + c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols, + padding, &output_cols)); ShapeHandle output_shape; TF_RETURN_IF_ERROR( @@ -1307,6 +1318,9 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, Status ScatterNdUpdateShape(InferenceContext* c) { ShapeHandle input_shape = c->input(0); + if (c->input_handle_shapes_and_types(0) != nullptr) { + input_shape = (*c->input_handle_shapes_and_types(0))[0].shape; + } ShapeHandle indices_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); ShapeHandle updates_shape; @@ -1361,7 +1375,9 @@ Status ScatterNdUpdateShape(InferenceContext* c) { } } - c->set_output(0, input_shape); + if (c->input_handle_shapes_and_types(0) == nullptr) { + c->set_output(0, input_shape); + } return Status::OK(); } diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index ec9746b2af..5f3e5ad457 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -423,6 +423,15 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { .Finalize(&op.node_def)); }; + // Invalid rank for input + INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]"); + // Invalid rank for filter + INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]"); + + // Invalid value for strides + set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO"); + INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]"); + // 1x1 filter set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); @@ -443,11 +452,6 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); - // Invalid rank for input - INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]"); - // Invalid rank for filter - INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]"); - // Unknown dims in the critical fields lead to partial inference. INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]"); @@ -538,6 +542,98 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]"); } +TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) { + ShapeInferenceTestOp op("Conv2D"); + auto set_op = [&op](const std::vector<int32>& dilations, + const std::vector<int32>& strides, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv2D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("dilations", dilations) + .Attr("strides", strides) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + }; + + // Invalid rank for dilation + set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]"); + + // Invalid value for dilation + set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]"); + + // Tests for NHWC + // 1x1 filter, 2x1 dilations, 1x1 strides + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 1x1 filter, 2x1 dilations, 2x1 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]"); + + // 1x1 filter, 2x1 dilations, 2x2 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 3x3 filter, 2x1 dilations, 1x1 strides + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]"); + + // 3x3 filter, 2x1 dilations, 2x1 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]"); + + // 3x3 filter, 1x2 dilations, 2x2 strides + set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]"); + + // Tests for NCHW + // 1x1 filter, 2x1 dilations, 1x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 1x1 filter, 2x1 dilations, 2x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]"); + + // 1x1 filter, 2x1 dilations, 2x2 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 3x3 filter, 2x1 dilations, 1x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]"); + + // 3x3 filter, 2x1 dilations, 2x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]"); + + // 3x3 filter, 1x2 dilations, 2x2 strides + set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]"); + + // Some tests for "SAME" padding + + // 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride + set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride + set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]"); + + // 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride + set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); +} + TEST(CommonShapeFnsTest, Conv3DShapeTest) { ShapeInferenceTestOp op("Conv3D"); auto set_op = [&op](const std::vector<int32>& strides, diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 2b080e13fd..bdd5af064b 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -58,7 +58,7 @@ struct bfloat16 { explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) : bfloat16(static_cast<float>(val)) {} - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { + EIGEN_DEVICE_FUNC explicit operator float() const { float result; uint16_t* q = reinterpret_cast<uint16_t*>(&result); @@ -89,6 +89,10 @@ struct bfloat16 { return static_cast<int>(float(*this)); } + EIGEN_DEVICE_FUNC explicit operator long() const { + return static_cast<long>(float(*this)); + } + EIGEN_DEVICE_FUNC explicit operator char() const { return static_cast<char>(float(*this)); } @@ -121,15 +125,48 @@ struct bfloat16 { return static_cast<double>(float(*this)); } + static bfloat16 epsilon() { + bfloat16 x; + x.value = 0x3c00; // 0x1.0p-7 + return x; + } + uint16_t value; }; -inline bool operator==(const bfloat16 a, const bfloat16 b) { - return a.value == b.value; +inline bfloat16 operator+(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) + static_cast<float>(b)); } - -inline bool operator!=(const bfloat16 a, const bfloat16 b) { - return a.value != b.value; +inline bfloat16 operator-(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) - static_cast<float>(b)); +} +inline bfloat16 operator*(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) * static_cast<float>(b)); +} +inline bfloat16 operator/(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) / static_cast<float>(b)); +} +inline bfloat16 operator-(bfloat16 a) { + a.value ^= 0x8000; + return a; +} +inline bool operator<(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) < static_cast<float>(b); +} +inline bool operator<=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) <= static_cast<float>(b); +} +inline bool operator==(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) == static_cast<float>(b); +} +inline bool operator!=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) != static_cast<float>(b); +} +inline bool operator>(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) > static_cast<float>(b); +} +inline bool operator>=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) >= static_cast<float>(b); } } // end namespace tensorflow diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index c1511ebe34..9b24e3aa00 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -124,22 +124,23 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) { "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " - "DT_QINT32, DT_UINT32, DT_UINT64] } } }"); + "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }"); ExpectSuccess( b().Attr("a:{numbertype, variant}"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " - "DT_QINT32, DT_UINT32, DT_UINT64, DT_VARIANT] } } }"); + "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }"); ExpectSuccess(b().Attr("a:realnumbertype"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " - "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64] } } }"); + "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, " + "DT_BFLOAT16] } } }"); ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, " - "DT_VARIANT, DT_STRING] } } }"); + "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }"); ExpectSuccess(b().Attr("a:quantizedtype"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }"); @@ -216,12 +217,14 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) { b().Attr("a:list(realnumbertype)"), "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " - "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64] } } }"); + "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64" + "] } } }"); ExpectSuccess( b().Attr("a:list({realnumbertype, variant})"), "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " - "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64, DT_VARIANT] } } }"); + "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, " + "DT_VARIANT] } } }"); ExpectSuccess( b().Attr("a:list(quantizedtype)"), "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index faae19585d..48849f9dda 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -206,18 +206,18 @@ string DataTypeSliceString(const DataTypeSlice types) { } DataTypeVector AllTypes() { - return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, - DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128, - DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16, - DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT, - DT_UINT32, DT_UINT64}; + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, + DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128, + DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16, + DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT, + DT_UINT32, DT_UINT64, DT_BFLOAT16}; } #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) DataTypeVector RealNumberTypes() { - return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, - DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64}; + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, + DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64, DT_BFLOAT16}; } DataTypeVector QuantizedTypes() { @@ -227,14 +227,14 @@ DataTypeVector QuantizedTypes() { DataTypeVector RealAndQuantizedTypes() { return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8, - DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF}; + DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF, DT_BFLOAT16}; } DataTypeVector NumberTypes() { - return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, - DT_UINT8, DT_UINT16, DT_INT16, DT_INT8, - DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, - DT_QINT32, DT_HALF, DT_UINT32, DT_UINT64}; + return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, + DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, + DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF, DT_UINT32, + DT_UINT64, DT_BFLOAT16}; } #elif defined(__ANDROID_TYPES_FULL__) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index f02cb51038..f1edbbb602 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", ], ) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index dd389de636..ec44d11bdd 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/utils.h" namespace tensorflow { namespace grappler { @@ -264,6 +265,79 @@ bool IsEnterWithQueue(const Node& node) { return false; } +bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { + if (proto.unknown_rank()) { + return true; + } + for (const auto& dim : proto.dim()) { + if (dim.size() < 0) { + return true; + } + } + return false; +} + +void VerboseLogUnknownDimensionSources( + const Graph& graph, + const std::map<string, std::vector<OpInfo::TensorProperties>>& + input_properties_map, + const std::map<string, std::vector<OpInfo::TensorProperties>>& + output_properties_map) { + if (!VLOG_IS_ON(2)) { + return; + } + + VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:"; + + // Find all nodes in the graph for which we + // do not have any unknown dimensions in their inputs, but + // we have some unknown dimensions in their outputs. + for (const Node* const node : graph.nodes()) { + if (node->num_outputs() == 0) { + continue; + } + + const auto& input_properties = input_properties_map.at(node->name()); + const auto& output_properties = output_properties_map.at(node->name()); + + bool has_unknown_inputs = false; + for (int i = 0; i < node->num_inputs(); ++i) { + if (HasAnyUnknownDimensions(input_properties[i].shape())) { + has_unknown_inputs = true; + break; + } + } + + if (has_unknown_inputs) { + continue; + } + + for (int i = 0; i < node->num_outputs(); ++i) { + if (HasAnyUnknownDimensions(output_properties[i].shape())) { + string inputs = "input_shapes=["; + for (int i = 0; i < node->num_inputs(); ++i) { + inputs += + PartialTensorShape::DebugString(input_properties[i].shape()); + } + inputs += "]"; + + string outputs = "output_shapes=["; + for (int i = 0; i < node->num_outputs(); ++i) { + outputs += + PartialTensorShape::DebugString(output_properties[i].shape()); + } + outputs += "]"; + + VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op() + << ", " << inputs << ", " << outputs; + + // don't log again for this node + break; + } + } + } +} + } // namespace // Queue of nodes to process. Nodes can be enqueued in any order, but will be @@ -312,9 +386,15 @@ class SymbolicShapeRefiner { Status UpdateNode(const Node* node, bool relax, bool* refined) { return shape_refiner_->UpdateNode(node, relax, refined); } - Status SetShape(const Node* node, int output_port, - shape_inference::ShapeHandle shape) { - return shape_refiner_->SetShape(node, output_port, shape); + Status SetUnknownShape(const Node* node, int output_port) { + shape_inference::ShapeHandle shape = + GetUnknownOutputShape(node, output_port); + InferenceContext* ctx = GetContext(node); + if (ctx == nullptr) { + return errors::InvalidArgument("Missing context"); + } + ctx->set_output(output_port, shape); + return Status::OK(); } struct ShapeId { @@ -646,6 +726,23 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, return Status::OK(); } +Status GraphProperties::OverwriteFedPorts( + SymbolicShapeRefiner* shape_refiner, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* node, TopoQueue* new_shapes) const { + auto it = fed_ports.find(node->name()); + Status status; + if (it != fed_ports.end()) { + // It is possible to feed node output ports with tensors of any shape: as a + // result, the shape of a fed port is completely unknown. + for (const int output_port : it->second) { + status.Update(shape_refiner->SetUnknownShape(node, output_port)); + } + new_shapes->push(node); + } + return status; +} + // Manually propagate the input shape for Enter nodes and update any Merge node // outputs. Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, @@ -673,9 +770,10 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, return Status::OK(); } -Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner, - bool relax, const Node* n, - TopoQueue* new_shapes) { +Status GraphProperties::UpdateShapes( + SymbolicShapeRefiner* shape_refiner, bool relax, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* n, TopoQueue* new_shapes) const { if (n->IsEnter()) { // The Enter shape function always forwards an UnknownShape, so do the right // thing here. @@ -695,7 +793,9 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner, } } } - return Status::OK(); + // Nodes can be fed with any shape. The TensorFlow shape inference code can't + // handle this properly, so overwrite its behavior here. + return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes); } // Propagates the shapes in the transitive fan-out of <new_shapes>. @@ -703,6 +803,7 @@ Status GraphProperties::PropagateShapes( SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, const std::unordered_map<const Node*, std::unordered_set<const Node*>>& resources, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, int num_loops) const { // Limit the number of iterations to prevent infinite loops in the presence of // incorrect shape functions. The algoritm should converge in at most @@ -728,8 +829,8 @@ Status GraphProperties::PropagateShapes( for (const Edge* e : n->out_edges()) { if (!e->IsControlEdge()) { const Node* fanout = e->dst(); - TF_RETURN_IF_ERROR( - UpdateShapes(shape_refiner, relax, fanout, new_shapes)); + TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports, + fanout, new_shapes)); } } } @@ -803,7 +904,7 @@ Status GraphProperties::UpdateResource( return Status::OK(); } -Status GraphProperties::InferStatically() { +Status GraphProperties::InferStatically(bool assume_valid_feeds) { Graph graph(OpRegistry::Global()); FunctionLibraryDefinition function_library(graph.op_registry(), item_.graph.library()); @@ -820,11 +921,21 @@ Status GraphProperties::InferStatically() { Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); + std::unordered_map<string, std::unordered_set<int>> fed_ports; + if (!assume_valid_feeds) { + for (const auto& feed : item_.feed) { + int port_index = 0; + string node_name = ParseNodeName(feed.first, &port_index); + fed_ports[node_name].insert(port_index); + } + } + // List the resources and the nodes using them. Also collect the Enter and // Merge nodes. std::unordered_map<const Node*, std::unordered_set<const Node*>> resources; std::unordered_set<const Node*> enter_nodes; std::unordered_set<const Node*> merge_nodes; + std::unordered_set<const Node*> fed_nodes; int num_loops = 0; for (const Node* const node : graph.nodes()) { for (int i = 0; i < node->num_inputs(); ++i) { @@ -841,6 +952,9 @@ Status GraphProperties::InferStatically() { } else if (node->IsNextIteration()) { ++num_loops; } + if (fed_ports.find(node->name()) != fed_ports.end()) { + fed_nodes.insert(node); + } } SymbolicShapeRefiner refiner(&shape_refiner); @@ -855,15 +969,22 @@ Status GraphProperties::InferStatically() { // Force the propagation of shapes of Enter nodes manually (the Enter shape // function always forwards an UnknownShape). for (const Node* node : enter_nodes) { - TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes)); + TF_RETURN_IF_ERROR( + UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes)); } // Seed the propagation of shapes through merge nodes. for (const Node* node : merge_nodes) { - TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes)); + TF_RETURN_IF_ERROR( + UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes)); + } + // Also seed the propagation of shapes in the fanout of fed nodes. + for (const Node* node : fed_nodes) { + TF_RETURN_IF_ERROR( + OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes)); } // Propagate shapes normally. - TF_RETURN_IF_ERROR( - PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops)); + TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources, + fed_ports, num_loops)); } // Track shapes globally across the graph. @@ -874,6 +995,10 @@ Status GraphProperties::InferStatically() { if (!node_ctx) { continue; } + // Skip any information that comes from fed nodes. + if (fed_ports.find(node->name()) != fed_ports.end()) { + continue; + } for (const auto& merged_shapes : node_ctx->MergedShapes()) { if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second) .ok()) { @@ -948,6 +1073,10 @@ Status GraphProperties::InferStatically() { } } + // Help trace the unknown dimensions to their origins. + VerboseLogUnknownDimensionSources(graph, input_properties_, + output_properties_); + return Status::OK(); } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 95bc5044d0..6fc53a7f2e 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -34,12 +34,19 @@ class TopoQueue; // nodes, and potentially a set of nodes to feed. class GraphProperties { public: - // Factory method for creating a GrapplerShapes from a MetaGraphDef. - // Returns nullptr if the given meta_graph cannot be converted. explicit GraphProperties(const GrapplerItem& item) : item_(item) {} - Status InferStatically(); + // Infer the shapes through abstract interpretation. Feed information can be + // incorrect so it should be discarded to ensure correctness of the analysis. + // However, it can help infer shapes in the fanout of fed nodes (even though + // the correctness of these shapes can't be guaranteed), so in some cases + // (such as simulation or scheduling) it makes sense of keep these shapes. + Status InferStatically(bool assume_valid_feeds); + // Infer the shape by running the graph on the specified cluster and recording + // the shapes of the processed tensors. Status InferDynamically(Cluster* cluster); + // Extract the properties from a cost graph. For testing only since there is + // no way to ensure that the cost graph match the item. Status InferFromCostGraph(const CostGraphDef& cost_graph); // Stores `item_.graph` with the inferred output shapes to `output_graph_def`. @@ -65,12 +72,6 @@ class GraphProperties { OpInfo::TensorProperties*); private: - // Inputs - GrapplerItem item_; - std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_; - std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_; - const std::vector<OpInfo::TensorProperties> missing_properties_; - // Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into // <*queue_shapes_and_types>. static Status MergeEnqueueShapesAndTypes( @@ -99,17 +100,31 @@ class GraphProperties { static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner, const Node* node, bool relax, TopoQueue* new_shapes); + // Process a node that is used to feed the model. + Status OverwriteFedPorts( + SymbolicShapeRefiner* shape_refiner, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* node, TopoQueue* new_shapes) const; // Update the shapes for node 'n'. If output shapes for n have changed, // enqueue its fanout in 'new_shapes'. - static Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax, - const Node* n, TopoQueue* new_shapes); + Status UpdateShapes( + SymbolicShapeRefiner* shape_refiner, bool relax, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* n, TopoQueue* new_shapes) const; // Propagate the shapes for the nodes enqueued in new_shapes and their // transitive fanout until a fixed point is reached. Status PropagateShapes( SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, const std::unordered_map<const Node*, std::unordered_set<const Node*>>& resources, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, int num_loops) const; + + // Data members + GrapplerItem item_; + std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_; + std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_; + const std::vector<OpInfo::TensorProperties> missing_properties_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index c11af5777a..cc40ff2cfc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -73,7 +73,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) { CHECK(fake_input.NextItem(&item)); GraphProperties properties(item); - Status s = properties.InferStatically(); + Status s = properties.InferStatically(true); TF_CHECK_OK(s); for (const auto& node : item.graph.node()) { @@ -179,7 +179,7 @@ TEST_F(GraphPropertiesTest, Variables) { { GraphProperties static_properties(item); - TF_CHECK_OK(static_properties.InferStatically()); + TF_CHECK_OK(static_properties.InferStatically(false)); const auto props = static_properties.GetOutputProperties("Var"); EXPECT_EQ(1, props.size()); @@ -219,7 +219,7 @@ TEST_F(GraphPropertiesTest, VarHandles) { .Finalize(item.graph.add_node())); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("VarRead"); EXPECT_EQ(1, props.size()); @@ -286,7 +286,7 @@ TEST_F(GraphPropertiesTest, Queues) { TF_CHECK_OK(root.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props1 = properties.GetOutputProperties("Dequeue1"); ASSERT_EQ(1, props1.size()); @@ -335,7 +335,7 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) { "merge_without_loops.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"}; std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]", @@ -377,7 +377,7 @@ TEST_F(GraphPropertiesTest, WhileLoop) { "while_loop.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -435,7 +435,7 @@ TEST_F(GraphPropertiesTest, NestedLoop) { "nested_loop.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -498,7 +498,7 @@ TEST_F(GraphPropertiesTest, LoopsAndQueues) { "loops_and_queues.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -556,7 +556,7 @@ TEST_F(GraphPropertiesTest, LoopsAndResourceVars) { "loops_and_resource_vars.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -608,7 +608,7 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) { "queues_and_loops.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -657,7 +657,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) { item.fetch.push_back("init_restore"); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto restore_props = properties.GetOutputProperties("restore"); const OpInfo::TensorProperties& restore_prop = restore_props[0]; @@ -704,7 +704,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { item.fetch.push_back("init2"); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("restore"); const OpInfo::TensorProperties& prop = props[0]; @@ -732,7 +732,7 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) { "simple_function.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1"); const OpInfo::TensorProperties& prop = props[0]; EXPECT_EQ(DT_FLOAT, prop.dtype()); @@ -766,7 +766,7 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); const auto shape_c = properties.GetOutputProperties("c").at(0).shape(); EXPECT_EQ(2, shape_a.dim_size()); @@ -822,7 +822,7 @@ TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { GraphProperties properties(item); // This function should return OK, since it doesn't validate the colocation // constraints internally. - TF_EXPECT_OK(properties.InferStatically()); + TF_EXPECT_OK(properties.InferStatically(false)); } TEST_F(GraphPropertiesTest, ShapeTracking) { @@ -842,7 +842,7 @@ TEST_F(GraphPropertiesTest, ShapeTracking) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); const auto shape_b = properties.GetOutputProperties("b").at(0).shape(); const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape(); @@ -851,6 +851,65 @@ TEST_F(GraphPropertiesTest, ShapeTracking) { EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString()); } +TEST_F(GraphPropertiesTest, FedNodes) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + { + // Conservative shape analysis: the shape of fed ports should be unknown + GraphProperties properties(item); + Status s = properties.InferStatically(false); + TF_CHECK_OK(s); + for (const auto& node : item.graph.node()) { + if (node.op() == "Const") { + continue; + } + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + const OpInfo::TensorProperties& in_prop = in_props[0]; + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + const OpInfo::TensorProperties& out_prop = out_props[0]; + + if (node.name() == "x") { + // x is fed: its input should have a known shape, while its output + // doesn't + EXPECT_FALSE(in_prop.shape().unknown_rank()); + EXPECT_EQ(1, in_prop.shape().dim_size()); + EXPECT_EQ(2, in_prop.shape().dim(0).size()); + EXPECT_TRUE(out_prop.shape().unknown_rank()); + } else if (node.op() == "Square" || node.op() == "AddN") { + // These nodes are in the fanout of x: their shapes should be unknown. + EXPECT_TRUE(in_prop.shape().unknown_rank()); + EXPECT_TRUE(out_prop.shape().unknown_rank()); + } + } + } + { + // Optimistic shape analysis: the shape of fed ports should be derived from + // the shape of the fanin. + GraphProperties properties(item); + Status s = properties.InferStatically(true); + TF_CHECK_OK(s); + for (const auto& node : item.graph.node()) { + if (node.op() == "Square" || node.op() == "AddN") { + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + const OpInfo::TensorProperties& in_prop = in_props[0]; + EXPECT_EQ(DT_FLOAT, in_prop.dtype()); + EXPECT_FALSE(in_prop.shape().unknown_rank()); + EXPECT_EQ(2, in_prop.shape().dim_size()); + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + const OpInfo::TensorProperties& out_prop = out_props[0]; + EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString()); + } + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index e5e1ee3292..6640de668d 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -122,7 +122,7 @@ Status VirtualScheduler::Init() { // Construct graph properties. Status status; if (use_static_shapes_) { - status = graph_properties_.InferStatically(); + status = graph_properties_.InferStatically(true); } else { status = graph_properties_.InferDynamically(cluster_); } diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 36c7f92c49..da99777bbc 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -173,7 +173,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( << ", skipping this input."; return nullptr; } - LOG(INFO) << "Will use feed node " << feed_name; + VLOG(1) << "Will use feed node " << feed_name; new_item->feed.emplace_back(feed_name, Tensor()); } @@ -188,7 +188,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( << ", skipping this input"; return nullptr; } - LOG(INFO) << "Will use fetch node " << name; + VLOG(1) << "Will use fetch node " << name; new_item->fetch.push_back(name); } } diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 5d9eb8e0b1..7b4ed10e7e 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -96,6 +96,7 @@ cc_library( ":graph_optimizer", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", @@ -332,6 +333,11 @@ tf_cc_test( deps = [ ":layout_optimizer", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 3cfc4f61e4..efe8ac05a3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -253,6 +253,30 @@ bool IsNumberType(DataType dtype) { const char kOutputShapesAttr[] = "_output_shapes"; +PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) { + int output_pos; + string node_name = ParseNodeName(input, &output_pos); + const NodeDef* input_node = node_map.GetNode(node_name); + return input_node->attr().at(kOutputShapesAttr).list().shape(output_pos); +} + +bool ShapesEqual(const string& input_x, const string& input_y, + const NodeMap& node_map) { + PartialTensorShape x_shape = GetInputShape(input_x, node_map); + PartialTensorShape y_shape = GetInputShape(input_y, node_map); + if (x_shape.unknown_rank() || y_shape.unknown_rank() || + x_shape.dims() != y_shape.dims()) { + return false; + } + for (int i = 0; i < x_shape.dims(); ++i) { + if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 || + x_shape.dim_size(i) != y_shape.dim_size(i)) { + return false; + } + } + return true; +} + // Returns whether `reshape` is an identity op. The tensor that `reshape` // reshapes is the `output_pos`-th output of node `input`. bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, @@ -868,8 +892,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // multiplication over addition to hoist common factors out of aggregate nodes // where all the inputs are Mul nodes. This pattern occurs frequently in // regularization terms for the gradients during training. - // TODO(rmlarsen): Check shapes and enable for AddN. - if (IsAdd(*node) && NumNonControlInputs(*node) > 1 && + // For example, we can rewrite an expression of the form: + // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) + // to the following: + // Mul(x, AddN(y1, y2, y3, ... yn)) + if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 && !OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set<string> common_factors; @@ -899,24 +926,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } if (common_factors.size() == 1) { const string& common_factor = *common_factors.begin(); - // In this case we have an expression of the form - // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) - // that can be rewritten as - // Mul(x, AddN(y1, y2, y3, ... yn)) - - // 1. Use a copy of the first Mul node for the outer multiplication. - NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"), - node_map_->GetNode(node->input(0))); - NodeDef* new_add_node = AddNode(StrCat(node->name(), "_hoist_add"), node); - new_mul_node->set_device(node->device()); - new_mul_node->set_input(0, common_factor); - node_map_->AddOutput(common_factor, new_mul_node->name()); - new_mul_node->set_input(1, new_add_node->name()); - node_map_->AddOutput(new_add_node->name(), new_mul_node->name()); - - // 2. Hoist non-shared factors up into the new AddN node. - nodes_to_simplify->PushBack(new_add_node); - for (int i = 0; i < node->input_size(); ++i) { + + // Gather up the non-shared factors (the y's in the example). + // Unless the aggregation is Add, we have to make sure that all the y's + // have the same shape since the other aggregation ops do not support + // broadcasting. + std::vector<string> unique_factors; + unique_factors.reserve(node->input_size()); + bool shapes_match = true; + for (int i = 0; i < node->input_size() && shapes_match; ++i) { const string& input = node->input(i); if (IsControlInput(input)) { break; @@ -924,15 +942,41 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* mul_node = node_map_->GetNode(input); const int unique_factor_index = mul_node->input(0) == common_factor ? 1 : 0; - const string unique_factor = mul_node->input(unique_factor_index); - new_add_node->set_input(i, unique_factor); + unique_factors.push_back(mul_node->input(unique_factor_index)); + if (i > 0 && !IsAdd(*node)) { + shapes_match = ShapesEqual(unique_factors.front(), + unique_factors.back(), *node_map_); + } } - // 4. Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, - {new_add_node}); + if (shapes_match) { + // 1. Use a copy of the first Mul node for the outer multiplication. + NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"), + node_map_->GetNode(node->input(0))); + NodeDef* new_add_node = + AddNode(StrCat(node->name(), "_hoist_add"), node); + new_mul_node->set_device(node->device()); + new_mul_node->set_input(0, common_factor); + node_map_->AddOutput(common_factor, new_mul_node->name()); + new_mul_node->set_input(1, new_add_node->name()); + node_map_->AddOutput(new_add_node->name(), new_mul_node->name()); + + // 2. Hoist non-shared factors up into the new AddN node. + nodes_to_simplify->PushBack(new_add_node); + for (int i = 0; i < node->input_size(); ++i) { + const string& input = node->input(i); + if (IsControlInput(input)) { + break; + } + new_add_node->set_input(i, unique_factors[i]); + } - return new_mul_node->name(); + // 3. Add frame dependencies that the original node might have had. + AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, + {new_add_node}); + + return new_mul_node->name(); + } } } @@ -1064,13 +1108,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, int num_frames; TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, &frame_map_, &num_frames)); - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - graph_properties_.reset(new GraphProperties(item)); - // Shapes are only needed in aggressive mode. - TF_RETURN_IF_ERROR(graph_properties_->InferStatically()); - TF_RETURN_IF_ERROR( - graph_properties_->AnnotateOutputShapes(optimized_graph_)); - } + graph_properties_.reset(new GraphProperties(item)); + // Shapes are only needed in aggressive mode. + TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); + TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_)); // Perform the optimizations. DedupComputations(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index e8a18ff9d9..80f42694d9 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -32,6 +32,21 @@ string OptimizedName(const string& name) { return AddPrefixToNodeName(name, kArithmeticOptimizer); } +void VerifyGraphsMatch(const GraphDef& original_graph, + const GraphDef& optimized_graph, int line) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << line; + EXPECT_EQ(original.op(), optimized.op()) << line; + EXPECT_EQ(original.input_size(), optimized.input_size()) << line; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << line; + } + } +} + class ArithmeticOptimizerTest : public ::testing::Test {}; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -44,18 +59,7 @@ TEST_F(ArithmeticOptimizerTest, NoOp) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - - EXPECT_EQ(item.graph.node_size(), output.node_size()); - for (int i = 0; i < item.graph.node_size(); ++i) { - const NodeDef& original = item.graph.node(i); - const NodeDef& optimized = output.node(i); - EXPECT_EQ(original.name(), optimized.name()); - EXPECT_EQ(original.op(), optimized.op()); - EXPECT_EQ(original.input_size(), optimized.input_size()); - for (int j = 0; j < original.input_size(); ++j) { - EXPECT_EQ(original.input(j), optimized.input(j)); - } - } + VerifyGraphsMatch(item.graph, output, __LINE__); } TEST_F(ArithmeticOptimizerTest, OpDedupping) { @@ -398,39 +402,51 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { } TEST_F(ArithmeticOptimizerTest, HoistFactor) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); - Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); - Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2}); - Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1); - Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x); - Output add = ops::Add(s.WithOpName("add"), mul1, mul2); - Output id = ops::Identity(s.WithOpName("id"), add); - - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - ArithmeticOptimizer optimizer; - GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - - EXPECT_EQ(9, output.node_size()); - const NodeDef& new_add = output.node(8); - EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name()); - EXPECT_EQ("y1", new_add.input(0)); - EXPECT_EQ("y2", new_add.input(1)); - const NodeDef& new_mul = output.node(7); - EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name()); - EXPECT_EQ("x", new_mul.input(0)); - EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1)); - const NodeDef& new_id = output.node(6); - EXPECT_EQ("id", new_id.name()); - EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0)); + for (bool matching_shapes : {true, false}) { + for (bool use_addn : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); + Output y2 = matching_shapes + ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2}) + : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1}); + Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1); + Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x); + Output id = + use_addn ? ops::Identity(s.WithOpName("id"), + ops::AddN(s.WithOpName("add"), {mul1, mul2})) + : ops::Identity(s.WithOpName("id"), + ops::Add(s.WithOpName("add"), mul1, mul2)); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + if (use_addn && !matching_shapes) { + VerifyGraphsMatch(item.graph, output, __LINE__); + } else { + EXPECT_EQ(9, output.node_size()); + const NodeDef& new_add = output.node(8); + EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name()); + EXPECT_EQ("y1", new_add.input(0)); + EXPECT_EQ("y2", new_add.input(1)); + const NodeDef& new_mul = output.node(7); + EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name()); + EXPECT_EQ("x", new_mul.input(0)); + EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1)); + const NodeDef& new_id = output.node(6); + EXPECT_EQ("id", new_id.name()); + EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0)); + } + } + } } TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index c77b2badf4..e0f39c2931 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -30,13 +30,16 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/bcast.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { namespace grappler { @@ -95,7 +98,38 @@ class DeviceSimple : public DeviceBase { std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_; }; +template <typename T> +bool AllValuesAre(const TensorProto& tensor, const T& value) { + // TensorProto represents the content of the tensor in either <type>_val or + // tensor_content. + typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values = + checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor)); + if (!tensor_values->empty()) { + for (const T& tensor_value : *tensor_values) { + if (tensor_value != value) { + return false; + } + } + return true; + } + const auto tensor_content_size = tensor.tensor_content().size(); + if (tensor_content_size > 0) { + CHECK_EQ(0, tensor_content_size % sizeof(T)); + std::vector<T> raw_values(tensor_content_size / sizeof(T)); + port::CopyToArray(tensor.tensor_content(), + reinterpret_cast<char*>(raw_values.data())); + for (int i = 0; i < tensor_content_size / sizeof(T); ++i) { + if (raw_values[i] != value) { + return false; + } + } + return true; + } + return false; +} + } // namespace + ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device) : opt_level_(opt_level), cpu_device_(cpu_device) { @@ -190,14 +224,21 @@ Status ConvertShapeToConstant(const string& op, const DataType& type, return Status::OK(); } -Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, - const GraphProperties& properties) { +bool ConstantFolding::IsReallyConstant(const NodeDef& node) const { + if (!IsConstant(node)) { + return false; + } + // If the node is fed it's not constant anymore. + return feed_nodes_.find(node.name()) == feed_nodes_.end(); +} + +Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // We may add some nodes to the graph to encode control dependencies: there is // no need to process these, so only iterate over the nodes of the input // graph. - const int node_count = graph_.node_size(); + const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { - NodeDef& node = *graph_.mutable_node(i); + NodeDef& node = *graph_->mutable_node(i); const string op = node.op(); if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") { continue; @@ -241,7 +282,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, // cases where the shape/rank/size would have been run in // the original graph. Additional inputs are extra control string ctrl_dep = - AddControlDependency(node.input(0), &graph_, node_map_.get()); + AddControlDependency(node.input(0), graph_, node_map_.get()); node.set_input(0, ctrl_dep); node_map_->AddOutput(NodeName(ctrl_dep), node.name()); } else { @@ -256,7 +297,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, AddPrefixToNodeName(strings::StrCat(node.name(), "-", j), kConstantFoldingConst); if (node_map_->GetNode(const_name) == nullptr) { - NodeDef* added_node = graph_.add_node(); + NodeDef* added_node = graph_->add_node(); added_node->set_name(const_name); added_node->set_op("Const"); added_node->set_device(node.device()); @@ -267,7 +308,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, // We add a control dependency to the original ShapeN node, // so that the node will only be run if all inputs of the // original ShapeN node are run. - string ctrl_dep = AddControlDependency(node.name(), &graph_, + string ctrl_dep = AddControlDependency(node.name(), graph_, node_map_.get()); *added_node->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), added_node->name()); @@ -285,6 +326,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, return Status::OK(); } +namespace { bool ShapesEqual(const TensorShapeProto& shape1, const TensorShapeProto& shape2) { if (shape1.unknown_rank() || shape2.unknown_rank()) { @@ -297,11 +339,13 @@ bool ShapesEqual(const TensorShapeProto& shape1, if (shape1.dim(i).size() != shape2.dim(i).size()) { return false; } + if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) { + return false; + } } return true; } -namespace { bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, BCast::Vec* shape, int64* min_id) { if (shape_node.op() == "Shape") { @@ -344,9 +388,9 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); if (shape_node1 == nullptr || - (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || + (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) || shape_node2 == nullptr || - (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { + (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) { return Status::OK(); } int64 min_id = 0; @@ -392,13 +436,13 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( strings::StrCat(node.name(), "-", j), kConstantFoldingConst); out[j] = node_map_->GetNode(const_name); if (out[j] == nullptr) { - out[j] = graph_.add_node(); + out[j] = graph_->add_node(); Tensor value(type, TensorShape({0})); *out[j] = CreateNodeDef(const_name, TensorValue(&value)); out[j]->set_device(node.device()); node_map_->AddNode(const_name, out[j]); string ctrl_dep = - AddControlDependency(node.name(), &graph_, node_map_.get()); + AddControlDependency(node.name(), graph_, node_map_.get()); *out[j]->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), const_name); } @@ -426,7 +470,7 @@ Status ConstantFolding::MaterializeReductionIndices( return Status::OK(); } const NodeDef* indices = node_map_->GetNode(node->input(1)); - if (!indices || IsConstant(*indices)) { + if (!indices || IsReallyConstant(*indices)) { // The reduction indices are already constant, there's nothing to do. return Status::OK(); } @@ -479,7 +523,7 @@ Status ConstantFolding::MaterializeReductionIndices( if (node_map_->GetNode(const_name)) { return Status::OK(); } - NodeDef* reduction_indices = graph_.add_node(); + NodeDef* reduction_indices = graph_->add_node(); Tensor value(dtype, TensorShape({rank})); for (int i = 0; i < rank; ++i) { if (dtype == DT_INT32) { @@ -491,7 +535,7 @@ Status ConstantFolding::MaterializeReductionIndices( *reduction_indices = CreateNodeDef(const_name, TensorValue(&value)); reduction_indices->set_device(node->device()); string ctrl_dep = - AddControlDependency(node->input(1), &graph_, node_map_.get()); + AddControlDependency(node->input(1), graph_, node_map_.get()); *reduction_indices->add_input() = ctrl_dep; node_map_->AddNode(const_name, reduction_indices); node_map_->AddOutput(NodeName(ctrl_dep), const_name); @@ -504,10 +548,10 @@ Status ConstantFolding::MaterializeReductionIndices( } Status ConstantFolding::MaterializeConstants( - const GrapplerItem& item, const GraphProperties& properties) { - const int node_count = graph_.node_size(); + const GraphProperties& properties) { + const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { - NodeDef& node = *graph_.mutable_node(i); + NodeDef& node = *graph_->mutable_node(i); const string& op = node.op(); if (op == "BroadcastGradientArgs") { TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties)); @@ -523,24 +567,23 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (node.input().empty()) { return false; } - // Skips nodes that must be preserved except whitelisted nodes. if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() && nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { return false; } - - // Skips ops that don't benefit from folding. - const string& op = node.op(); - // Skip constants, they're already folded - if (op == "Const") { + // Skip control flow nodes, they can't be folded + if (ModifiesFrameInfo(node)) { return false; } - // Skip constrol flow nodes, they can't be folded - if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" || - op == "NextIteration" || op == "RefNextIteration") { + // Skip constants, they're already folded + if (IsConstant(node)) { return false; } + + // Skips ops that don't benefit from folding. + const string& op = node.op(); + if (op.find("Placeholder") == 0) { return false; } @@ -594,7 +637,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (!input_node) { return false; } - bool is_const = IsConstant(*input_node); + bool is_const = IsReallyConstant(*input_node); if (!is_const && !is_merge) { return false; } @@ -612,6 +655,36 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return true; } +namespace { + +#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \ + case DTYPE: \ + t->add_##NAME##_val(static_cast<TYPE>(value)); \ + break; + +Status CreateConstantTensorAttrValue(DataType type, double value, + const TensorShapeProto& shape, + AttrValue* attr_tensor) { + TensorProto* t = attr_tensor->mutable_tensor(); + *t->mutable_tensor_shape() = shape; + switch (type) { + SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); + SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double); + SET_TENSOR_VAL_CASE(DT_INT64, int64, int64); + SET_TENSOR_VAL_CASE(DT_INT32, int32, int); + SET_TENSOR_VAL_CASE(DT_INT16, int32, int); + SET_TENSOR_VAL_CASE(DT_INT8, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); + SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); + default: + return errors::InvalidArgument("Unsupported type: ", type); + } + return Status::OK(); +} + +#undef SET_TENSOR_CAL_CASE +} // namespace + // static NodeDef ConstantFolding::CreateNodeDef(const string& name, const TensorValue& tensor) { @@ -652,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, POPULATE_TENSOR_PROTO(tensor, t, int64, int64) } else if (tensor->dtype() == DT_INT32) { POPULATE_TENSOR_PROTO(tensor, t, int32, int) + } else if (tensor->dtype() == DT_INT16) { + POPULATE_TENSOR_PROTO(tensor, t, int16, int) + } else if (tensor->dtype() == DT_INT8) { + POPULATE_TENSOR_PROTO(tensor, t, int8, int) + } else if (tensor->dtype() == DT_UINT8) { + POPULATE_TENSOR_PROTO(tensor, t, uint8, int) + } else if (tensor->dtype() == DT_BOOL) { + POPULATE_TENSOR_PROTO(tensor, t, bool, bool) } } if (optimized) { @@ -720,7 +801,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, break; } const NodeDef* input_node = node_map_->GetNode(input); - if (!IsConstant(*input_node)) { + if (!IsReallyConstant(*input_node)) { return Status(error::INVALID_ARGUMENT, strings::StrCat("Can't fold ", node.name(), ", its ", input, " isn't constant")); @@ -774,7 +855,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) { continue; } NodeDef* input_node = node_map_->GetNode(input); - if (!IsConstant(*input_node)) { + if (!IsReallyConstant(*input_node)) { continue; } bool valid_input = true; @@ -955,8 +1036,8 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) { Status ConstantFolding::FoldGraph(GraphDef* output) { std::unordered_set<string> processed_nodes; std::deque<NodeDef*> queue; - for (int i = 0; i < graph_.node_size(); i++) { - auto node = graph_.mutable_node(i); + for (int i = 0; i < graph_->node_size(); i++) { + auto node = graph_->mutable_node(i); if (IsFoldable(*node)) { queue.push_back(node); } @@ -995,7 +1076,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { output->mutable_node()->DeleteSubrange(last + 1, output->node_size() - last - 1); - for (const auto& node : graph_.node()) { + for (const auto& node : graph_->node()) { // If no fetch nodes is provided, we conservatively // keep all nodes in the original graph in case users need to fetch // their values. @@ -1016,7 +1097,7 @@ bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const { if (IsReduction(node)) { CHECK_LE(2, node.input_size()); const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); - if (IsConstant(*reductions_indices)) { + if (IsReallyConstant(*reductions_indices)) { TensorVector output; Status s = EvaluateNode(*reductions_indices, TensorVector(), &output); if (!s.ok()) { @@ -1040,7 +1121,7 @@ bool ConstantFolding::IsSimplifiableReshape( } CHECK_LE(2, node.input_size()); const NodeDef* new_shape = node_map_->GetNode(node.input(1)); - if (!IsConstant(*new_shape)) { + if (!IsReallyConstant(*new_shape)) { return false; } TensorVector outputs; @@ -1090,8 +1171,107 @@ bool ConstantFolding::IsSimplifiableReshape( return shape.IsCompatibleWith(new_dims); } +#define IS_VALUE_CASE(DTYPE, VALUE) \ + case DTYPE: \ + return AllValuesAre<EnumToDataType<DTYPE>::Type>( \ + node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE)) + +#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1) +#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0) + +bool ConstantFolding::IsOnes(const NodeDef& node) const { + if (feed_nodes_.find(node.name()) != feed_nodes_.end()) { + return false; + } + if (node.op() == "OnesLike") { + return true; + } + if (node.op() != "Const") { + return false; + } + const auto dtype = node.attr().at("dtype").type(); + switch (dtype) { + // IS_ONES_CASE(DT_HALF); + IS_ONES_CASE(DT_FLOAT); + IS_ONES_CASE(DT_DOUBLE); + IS_ONES_CASE(DT_UINT8); + IS_ONES_CASE(DT_INT8); + IS_ONES_CASE(DT_UINT16); + IS_ONES_CASE(DT_INT16); + IS_ONES_CASE(DT_INT32); + IS_ONES_CASE(DT_INT64); + IS_ONES_CASE(DT_COMPLEX64); + IS_ONES_CASE(DT_COMPLEX128); + default: + LOG(ERROR) << "Unexpected type " << DataTypeString(dtype); + return false; + } + return false; +} + +bool ConstantFolding::IsZeros(const NodeDef& node) const { + if (feed_nodes_.find(node.name()) != feed_nodes_.end()) { + return false; + } + if (node.op() == "ZerosLike") { + return true; + } + if (!IsConstant(node)) { + return false; + } + const auto dtype = node.attr().at("dtype").type(); + switch (dtype) { + // IS_ZEROS_CASE(DT_HALF); + IS_ZEROS_CASE(DT_FLOAT); + IS_ZEROS_CASE(DT_DOUBLE); + IS_ZEROS_CASE(DT_UINT8); + IS_ZEROS_CASE(DT_INT8); + IS_ZEROS_CASE(DT_UINT16); + IS_ZEROS_CASE(DT_INT16); + IS_ZEROS_CASE(DT_INT32); + IS_ZEROS_CASE(DT_INT64); + IS_ZEROS_CASE(DT_COMPLEX64); + IS_ZEROS_CASE(DT_COMPLEX128); + default: + LOG(ERROR) << "Unexpected type " << DataTypeString(dtype); + return false; + } + return false; +} + +void ConstantFolding::ReplaceAddOrMulWithIdentity(int input_to_forward, + NodeDef* node) { + node->set_op("Identity"); + // Propagate the designated input through the identity. + node->mutable_input()->SwapElements(0, input_to_forward); + // Add all other inputs as control dependencies. + for (int i = 1; i < node->input_size(); ++i) { + node->set_input(i, AsControlDependency(node->input(i))); + } + graph_modified_ = true; +} + +Status ConstantFolding::ReplaceAddOrMulWithConstant( + double value, const TensorShapeProto& shape, NodeDef* node) { + AttrValue tensor_attr; + TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(node->attr().at("T").type(), + value, shape, &tensor_attr)); + node->mutable_attr()->insert({"value", tensor_attr}); + node->set_op("Const"); + // Convert all inputs to control dependencies. + for (int i = 0; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + node->set_input(i, AsControlDependency(node->input(i))); + } + graph_modified_ = true; + return Status::OK(); +} + Status ConstantFolding::SimplifyGraph(GraphDef* output, - const GraphProperties& properties) { + const GraphProperties& properties, + bool use_shape_info) { for (auto& node : *output->mutable_node()) { if (IsSimplifiableReduction(node)) { // Replace the reduction node with an identity node, that can be further @@ -1116,10 +1296,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, *node.add_input() = input; } } - // It's possible to feed a placeholder with a tensor that doesn't have the - // proper shape, and reshape this tensor later on. Therefore only remove - // reshapes in graphs that don't have placeholders. - if (IsSimplifiableReshape(node, properties)) { + const bool safe_to_use_shapes = + use_shape_info && + (feed_nodes_.empty() || opt_level_ == RewriterConfig::AGGRESSIVE); + if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) { const NodeDef* new_shape = node_map_->GetNode(node.input(1)); DataType output_type = node.attr().at("T").type(); node.set_op("Identity"); @@ -1134,6 +1314,63 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, *node.add_input() = input; } } + + // Simplify multiplication by ones or zeros, and addition of zeros. + bool is_mul = IsMul(node); + bool is_add = IsAdd(node); + if (opt_level_ == RewriterConfig::AGGRESSIVE && use_shape_info && + (is_mul || is_add) && properties.HasInputProperties(node.name()) && + properties.HasOutputProperties(node.name())) { + const NodeDef* x = node_map_->GetNode(node.input(0)); + const NodeDef* y = node_map_->GetNode(node.input(1)); + if (x == nullptr || y == nullptr) { + return errors::InvalidArgument("Invalid inputs to node: ", + node.DebugString()); + } + const TensorShapeProto& output_shape = + properties.GetOutputProperties(node.name())[0].shape(); + const TensorShapeProto& x_shape = + properties.GetInputProperties(node.name())[0].shape(); + + // Simplify multiplication by or addition of zeros. + const bool x_is_zero = IsZeros(*x); + const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); + if (x_is_zero && x_matches_output_shape) { + // 0 * y = 0 or 0 + y = y. + ReplaceAddOrMulWithIdentity(is_mul ? 0 : 1, &node); + continue; + } + const TensorShapeProto& y_shape = + properties.GetInputProperties(node.name())[1].shape(); + const bool y_is_zero = IsZeros(*y); + const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); + if (y_is_zero && y_matches_output_shape) { + // x * 0 = 0 or x + 0 = x. + ReplaceAddOrMulWithIdentity(is_mul ? 1 : 0, &node); + continue; + } + + if (is_mul) { + // Simplify multiplication by zeros where the output shape does not + // match the shape of the zero input. + if (x_is_zero || y_is_zero) { + TF_RETURN_IF_ERROR( + ReplaceAddOrMulWithConstant(0, output_shape, &node)); + continue; + } + + // Simplify multiplication by ones. + if (IsOnes(*x) && y_matches_output_shape) { + // 1 * y = y. + ReplaceAddOrMulWithIdentity(1, &node); + continue; + } else if (IsOnes(*y) && x_matches_output_shape) { + // x * 1 = x. + ReplaceAddOrMulWithIdentity(0, &node); + continue; + } + } + } } return Status::OK(); } @@ -1141,7 +1378,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, Status ConstantFolding::RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { - node_map_.reset(new NodeMap(&graph_)); + node_map_.reset(new NodeMap(graph_)); nodes_whitelist_.clear(); // Fold fetch nodes iff it has a single fanout. Note that if a fetch node // has a single fanout, it would be rewritten as a constant with the same @@ -1158,36 +1395,34 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, } GraphProperties properties(item); - const bool has_feed = !item.feed.empty(); - bool needs_shapes = !has_feed || opt_level_ == RewriterConfig::AGGRESSIVE; - Status s = errors::Unknown( - "The graph properties are needed but were not initialized"); - if (needs_shapes) { - s = properties.InferStatically(); - } - - if (!has_feed && s.ok()) { - // Only use static shape information when there is no feed in the - // graph. That's because it's possible to feed a placeholder with a tensor - // of any shape, which could make the static information inconsistent with - // the shapes actually fed. - TF_RETURN_IF_ERROR(MaterializeShapes(item, properties)); - } - if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) { - TF_RETURN_IF_ERROR(MaterializeConstants(item, properties)); + // It's possible to feed a placeholder with a tensor of any shape: make sure + // that the shape inference deals with this conservatively unless we're in + // aggressive mode. + const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; + Status s = properties.InferStatically(assume_valid_feeds); + const bool can_use_shape_info = s.ok(); + + if (can_use_shape_info) { + TF_RETURN_IF_ERROR(MaterializeShapes(properties)); + + if (opt_level_ == RewriterConfig::AGGRESSIVE) { + TF_RETURN_IF_ERROR(MaterializeConstants(properties)); + } } TF_RETURN_IF_ERROR(FoldGraph(output)); - if (!has_feed && s.ok()) { - TF_RETURN_IF_ERROR(SimplifyGraph(output, properties)); - } + TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info)); + return Status::OK(); } Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { nodes_to_preserve_ = item.NodesToPreserve(); + for (const auto& feed : item.feed) { + feed_nodes_.insert(NodeName(feed.first)); + } if (cpu_device_ == nullptr) { owned_device_.reset(new DeviceSimple()); @@ -1200,13 +1435,13 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, *output = item.graph; int64 node_count; do { - graph_.Swap(output); - item_to_optimize.graph = graph_; + graph_modified_ = false; + item_to_optimize.graph.Swap(output); + graph_ = &item_to_optimize.graph; *output = GraphDef(); - node_count = graph_.node_size(); + node_count = graph_->node_size(); TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output)); - } while (output->node_size() != node_count); - + } while (graph_modified_ || output->node_size() != node_count); *output->mutable_library() = item.graph.library(); *output->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index f04f413c10..3bb9926338 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -51,16 +51,16 @@ class ConstantFolding : public GraphOptimizer { const GraphDef& optimize_output, double result) override; private: - Status MaterializeShapes(const GrapplerItem& item, - const GraphProperties& properties); + bool IsReallyConstant(const NodeDef& node) const; + + Status MaterializeShapes(const GraphProperties& properties); Status MaterializeBroadcastGradientArgs(const NodeDef& node, const GraphProperties& properties); Status MaterializeReductionIndices(NodeDef* node, const GraphProperties& properties); - Status MaterializeConstants(const GrapplerItem& item, - const GraphProperties& properties); + Status MaterializeConstants(const GraphProperties& properties); bool IsFoldable(const NodeDef& node) const; Status EvaluateNode(const NodeDef& node, @@ -72,12 +72,19 @@ class ConstantFolding : public GraphOptimizer { Status FoldNode(NodeDef* node, GraphDef* output_graph); + bool IsOnes(const NodeDef& node) const; + bool IsZeros(const NodeDef& node) const; + void ReplaceAddOrMulWithIdentity(int input_to_forward, NodeDef* node); + Status ReplaceAddOrMulWithConstant(double value, + const TensorShapeProto& shape, + NodeDef* node); Status FoldGraph(GraphDef* output); bool IsSimplifiableReduction(const NodeDef& node) const; bool IsSimplifiableReshape(const NodeDef& node, const GraphProperties& properties) const; - Status SimplifyGraph(GraphDef* output, const GraphProperties& properties); + Status SimplifyGraph(GraphDef* output, const GraphProperties& properties, + bool use_shape_info); Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, GraphDef* output); @@ -88,11 +95,13 @@ class ConstantFolding : public GraphOptimizer { std::unique_ptr<DeviceBase> owned_device_; std::unique_ptr<ResourceMgr> resource_mgr_; - GraphDef graph_; + GraphDef* graph_; std::unique_ptr<NodeMap> node_map_; std::unordered_set<string> nodes_to_preserve_; std::unordered_set<string> nodes_whitelist_; + std::unordered_set<string> feed_nodes_; bool has_fetch_; + bool graph_modified_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b2d9b02c68..32a691d3ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -77,11 +77,166 @@ TEST_F(ConstantFoldingTest, SimpleFolding) { test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); } +TEST_F(ConstantFoldingTest, NeutralElement) { + for (bool use_const : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({1, 2}))); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({1, 2}))); + Output zeros = + !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) + : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2}); + Output zeros_broadcast = + ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1}); + Output ones = !use_const + ? ops::OnesLike(s.WithOpName("ones"), x) + : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2}); + Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); + Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); + Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); + Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y); + Output mul5 = ops::Mul(s.WithOpName("mul1"), x, zeros_broadcast); + Output mul6 = ops::Mul(s.WithOpName("mul2"), zeros_broadcast, y); + Output add1 = ops::Add(s.WithOpName("add1"), x, zeros); + Output add2 = ops::Add(s.WithOpName("add2"), zeros, y); + Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(14, output.node_size()); + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + const string& name = node.name(); + if (name == "mul1") { + if (use_const) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^x", node.input(0)); + } else { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } + } else if (name == "mul2") { + if (use_const) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^y", node.input(0)); + } else { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + } + } else if (name == "mul3") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ones", node.input(1)); + } else if (name == "mul4") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^ones", node.input(1)); + } else if (name == "mul5") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^x", node.input(0)); + EXPECT_EQ("^ones", node.input(1)); + TensorProto t = node.attr().at("value").tensor(); + EXPECT_EQ(1, t.float_val_size()); + EXPECT_EQ(0, t.float_val(0)); + EXPECT_EQ(2, t.tensor_shape().dim_size()); + EXPECT_EQ(1, t.tensor_shape().dim(0).size()); + EXPECT_EQ(2, t.tensor_shape().dim(1).size()); + } else if (name == "mul6") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^y", node.input(0)); + EXPECT_EQ("^ones", node.input(1)); + TensorProto t = node.attr().at("value").tensor(); + EXPECT_EQ(1, t.float_val_size()); + EXPECT_EQ(0, t.float_val(0)); + EXPECT_EQ(2, t.tensor_shape().dim_size()); + EXPECT_EQ(1, t.tensor_shape().dim(0).size()); + EXPECT_EQ(2, t.tensor_shape().dim(1).size()); + } else if (name == "add1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^zeros", node.input(1)); + } else if (name == "add2") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^zeros", node.input(1)); + } + } + } +} + +TEST_F(ConstantFoldingTest, CreateConstNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + +#define MAKE_TEST_GRAPH(TYPE) \ + Output TYPE##_const = \ + ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \ + Output TYPE##_mul = \ + ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \ + Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul) + + MAKE_TEST_GRAPH(float); + MAKE_TEST_GRAPH(double); + MAKE_TEST_GRAPH(int64); + MAKE_TEST_GRAPH(int32); + MAKE_TEST_GRAPH(int16); + MAKE_TEST_GRAPH(int8); + MAKE_TEST_GRAPH(uint8); +#undef MAKE_TEST_GRAPH + + Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5}); + Output bool_and = + ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const); + Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(24, output.node_size()); + for (const NodeDef& node : output.node()) { +#define CHECK_RESULT(TYPE, FIELD) \ + if (node.name() == #TYPE "_mul") { \ + EXPECT_EQ(5, \ + node.attr().at("value").tensor().tensor_shape().dim(0).size()); \ + EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \ + EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \ + } + + CHECK_RESULT(float, float); + CHECK_RESULT(double, double); + CHECK_RESULT(int64, int64); + CHECK_RESULT(int32, int); + CHECK_RESULT(int16, int); + CHECK_RESULT(int8, int); + CHECK_RESULT(uint8, int); +#undef CHECK_RESULT + + if (node.name() == "bool_and") { + EXPECT_EQ(5, + node.attr().at("value").tensor().tensor_shape().dim(0).size()); + EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size()); + EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0)); + } + } +} + TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Const(s.WithOpName("a"), 10, {3}); + Output a = ops::Const(s.WithOpName("a"), 10, {5}); auto b = ops::Unique(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), {b.y}); Output d = ops::Identity(s.WithOpName("d"), {b.idx}); @@ -963,3 +1118,5 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { } // namespace } // namespace grappler } // namespace tensorflow + +// LocalWords: NewRootScope diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index d5563e9d4c..e9436638f0 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <deque> #include <unordered_set> #include "tensorflow/core/framework/attr_value.pb.h" @@ -69,6 +70,8 @@ std::set<string> GetOpsFormatSupported() { return ops_format_supported; } +// TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled +// because of the worse performance in some cases. std::set<string> GetOpsFormatAgnostic() { std::set<string> ops_format_agnostic = {"Add", "AddN", @@ -88,7 +91,7 @@ std::set<string> GetOpsFormatAgnostic() { "Split", "SquaredDifference", "Squeeze", - "Sub"}; + /*"Sum",*/ "Sub"}; return ops_format_agnostic; } @@ -186,33 +189,6 @@ class GraphProcessor { return node; } - NodeDef* AddNodeReductionConst(const string& name, const string& device) { - NodeDef* node = graph_->add_node(); - node_map_->AddNode(name, node); - node->set_name(name); - node->set_op("Const"); - AttrValue attr_data_type; - attr_data_type.set_type(DT_INT32); - node->mutable_attr()->insert({"dtype", attr_data_type}); - - AttrValue attr_tensor; - Tensor tensor(DT_INT32, TensorShape({3})); - std::vector<int> axis = {0, 2, 3}; - for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) { - tensor.flat<int>()(i) = axis[i]; - } - tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); - node->mutable_attr()->insert({"value", attr_tensor}); - string device_name; - if (device.empty()) { - device_name = virtual_placer_.get_canonical_device_name(*node); - } else { - device_name = device; - } - node->set_device(device_name); - return node; - } - const VirtualPlacer& virtual_placer_; const std::unordered_set<string>& nodes_to_preserve_; GraphDef* graph_; @@ -370,10 +346,20 @@ class NodeProcessor : public GraphProcessor { LOG(ERROR) << "Failed to parse TensorProto."; } if (tensor.dims() == 1) { - int c = tensor.flat<int>()(3); - tensor.flat<int>()(3) = tensor.flat<int>()(2); - tensor.flat<int>()(2) = tensor.flat<int>()(1); - tensor.flat<int>()(1) = c; + if (tensor.flat<int>().size() == 4) { + int c = tensor.flat<int>()(3); + tensor.flat<int>()(3) = tensor.flat<int>()(2); + tensor.flat<int>()(2) = tensor.flat<int>()(1); + tensor.flat<int>()(1) = c; + } else if (tensor.flat<int>().size() == 3) { + tensor.flat<int>()(0) = 0; + tensor.flat<int>()(1) = 2; + tensor.flat<int>()(2) = 3; + } else { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("Unsupported tensor size: ", + tensor.flat<int>().size())); + } } else if (tensor.dims() == 2) { for (int i = 0; i < 2; i++) { int c = tensor.matrix<int>()(3, i); @@ -394,7 +380,9 @@ class NodeProcessor : public GraphProcessor { Status UpdateAttrValueOfInput(int input_index) { auto input_node = node_map_->GetNode(node_->input(input_index)); // We created a copy of the node, so that we don't modify the original node, - // which might be used elsewhere. + // which might be used elsewhere. Note that this copy also copies the + // control dependency input in the case this node is inside a loop, + // to ensure added_node is in the same frame with node_. NodeDef* added_node = graph_->add_node(); *added_node = *input_node; string base_name = strings::StrCat(node_->name(), "-", input_node->name()); @@ -411,6 +399,14 @@ class NodeProcessor : public GraphProcessor { return input_pos; } + virtual std::set<int> GetOutputPos() const { + // For most nodes, no need to process control nodes or nodes that use an + // output other than the first output: only the first output is of + // 4D NCHW/NHWC format and thus relevant here. + std::set<int> output_pos = {0}; + return output_pos; + } + NodeDef* AddNodeTranspose(const string& node_name, const string& input_name, const string& const_name, DataType data_type, const TensorShapeProto& input_shape, @@ -476,37 +472,28 @@ class NodeProcessor : public GraphProcessor { auto outputs = node_map_->GetOutputs(node_->name()); string const_name = GetOrAddNodePermNCHWToNHWC(); for (const auto& output : outputs) { - string base_name = strings::StrCat(node_->name(), "-", output->name()); - string node_name = - AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-"); - // TODO(yaozhang): handle the rare case where node A is connected to more - // than one input of node B. - auto it = std::find_if(output->mutable_input()->begin(), - output->mutable_input()->end(), - [this](const string& input) { - string node_name = NodeName(input); - return node_name.compare(node_->name()) == 0; - }); - if (it == output->mutable_input()->end()) { - return Status(error::INVALID_ARGUMENT, - strings::StrCat("Expect ", node_->name(), - " to be an input of ", output->name())); - } - int output_pos = NodePosition(*it); - // No need to process control nodes or nodes that use an output - // other than the first output: only the first output is of 4D NCHW/NHWC - // format and thus relevant here. - if (output_pos != 0) { - continue; + for (int i = 0; i < output->input_size(); i++) { + auto& input = *output->mutable_input(i); + int input_port; + string input_name = ParseNodeName(input, &input_port); + auto output_pos = GetOutputPos(); + if (input_name == node_->name() && + output_pos.find(input_port) != output_pos.end()) { + string base_name = + strings::StrCat(node_->name(), "-", output->name(), "-", i); + string node_name = + AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-"); + TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); + TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes")); + AddNodeTranspose( + node_name, input, const_name, node_->attr().at("T").type(), + node_->attr().at("_output_shapes").list().shape(0), false); + input = node_name; + node_map_->AddOutput(node_->name(), node_name); + node_map_->AddOutput(node_name, output->name()); + } } - TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); - TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes")); - AddNodeTranspose( - node_name, node_->name(), const_name, node_->attr().at("T").type(), - node_->attr().at("_output_shapes").list().shape(0), false); - *it = node_name; - node_map_->UpdateOutput(node_->name(), output->name(), node_name); - node_map_->AddOutput(node_name, output->name()); + node_map_->RemoveOutput(node_->name(), output->name()); } return Status::OK(); } @@ -775,24 +762,52 @@ class AgnosticNodeProcessor : public NodeProcessor { bool IsNodeAfterNCHWToNHWC() const { std::set<string> ops_format_agnostic = GetOpsFormatAgnostic(); - auto node = node_map_->GetNode(node_->name()); - while (node->input_size() > 0) { - int data_input_pos = 0; - if (IsConcatV1(*node) || IsSplit(*node)) { - data_input_pos = 1; - } - node = node_map_->GetNode(node->input(data_input_pos)); - if (IsNodeNCHWToNHWC(node->name())) { + std::deque<NodeDef*> queue; + auto first_node_pos = DataInputPos(*node_); + for (const auto& pos : first_node_pos) { + auto input_node = node_map_->GetNode(node_->input(pos)); + queue.push_back(input_node); + } + // The code will exit this while loop in one iteration in most cases, as the + // graph is already topologically sorted. + while (!queue.empty()) { + NodeDef* current_node = queue.front(); + queue.pop_front(); + if (IsNodeNCHWToNHWC(current_node->name())) { return true; } - bool connected = - ops_format_agnostic.find(node->op()) != ops_format_agnostic.end(); - if (!connected) { - return false; + // We only continue searching if the path is connected through + // format-agnostic nodes. + if (ops_format_agnostic.find(current_node->op()) != + ops_format_agnostic.end()) { + auto current_node_pos = DataInputPos(*current_node); + for (const auto& pos : current_node_pos) { + auto input_node = node_map_->GetNode(current_node->input(pos)); + queue.push_back(input_node); + } } } return false; } + + private: + std::vector<int> DataInputPos(const NodeDef& node) const { + std::vector<int> pos; + if (IsSplit(node)) { + return {1}; + } + if (IsConcatV1(node)) { + return {1}; + } + if (IsAdd(node) || IsMul(node) || IsRealDiv(node) || + IsSquaredDifference(node) || IsSub(node)) { + return {0, 1}; + } + if (node.input_size() > 0 && !IsControlInput(node.input(0))) { + return {0}; + } + return {}; + } }; class AddNProcessor : public AgnosticNodeProcessor { @@ -815,42 +830,49 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { public: explicit BinaryOpProcessor(const OptimizeContext& opt_cxt) : AgnosticNodeProcessor(opt_cxt) { - is_4d_with_vector_ = Is4DOperateWithVector(); + is_4d_with_vector_ = IsNDOperateWithMD(4, 1); } protected: bool ShouldProcess() const override { + // TODO(yaozhang): Support IsNDOperateWithMD(1, 4): first input is a vector + // and the second input is a 4D tensor; and update CustomizedProcessing() + // accordingly. return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && - (Is4DOperateWithND(4) || Is4DOperateWithScalar() || - Is4DOperateWithVector()) && + (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) || + IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4)) && IsOnGPU(); } std::vector<int> GetInputPos() const override { - std::vector<int> input_pos = {0}; - if (Is4DOperateWithND(4)) { + std::vector<int> input_pos; + auto input0 = node_map_->GetNode(node_->input(0)); + auto input1 = node_map_->GetNode(node_->input(1)); + if (IsDimsFour(*input0)) { + input_pos.push_back(0); + } + if (IsDimsFour(*input1)) { input_pos.push_back(1); } return input_pos; } - bool Is4DOperateWithND(int n) const { + bool IsDimsFour(const NodeDef& node) const { + return NodeProcessor::IsDimsFour(node) || IsNodeNCHWToNHWC(node.name()); + } + + bool IsNDOperateWithMD(int n, int m) const { auto input0 = node_map_->GetNode(node_->input(0)); auto input1 = node_map_->GetNode(node_->input(1)); if (input0 && input1) { - return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && - ((n == 4) - ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name())) - : IsDimsN(*input1, n)); + bool input0_is_n = (n == 4) ? IsDimsFour(*input0) : IsDimsN(*input0, n); + bool input1_is_m = (m == 4) ? IsDimsFour(*input1) : IsDimsN(*input1, m); + return input0_is_n && input1_is_m; } return false; } - bool Is4DOperateWithScalar() const { return Is4DOperateWithND(0); } - - bool Is4DOperateWithVector() const { return Is4DOperateWithND(1); } - NodeDef* AddNodeShapeConst(const string& name, int num_channels) { NodeDef* node = graph_->add_node(); node_map_->AddNode(name, node); @@ -948,7 +970,7 @@ class ConcatProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - string concat_const_name = GetOrAddNodeConcatConst(); + string concat_const_name = AddNodeConcatConst()->name(); node_map_->AddOutput(concat_const_name, node_->name()); *node_->mutable_input(axis_node_pos_) = concat_const_name; return Status::OK(); @@ -956,8 +978,14 @@ class ConcatProcessor : public AgnosticNodeProcessor { bool IsAlongDimC() const { auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_)); + if (!IsConstant(*axis_node)) { + return false; + } if (axis_node->attr().find("value") != axis_node->attr().end()) { - return axis_node->attr().at("value").tensor().int_val(0) == 3; + auto tensor = axis_node->attr().at({"value"}).tensor(); + if (tensor.tensor_shape().dim_size() == 0 && tensor.int_val_size() == 1) { + return tensor.int_val(0) == 3; + } } return false; } @@ -965,28 +993,18 @@ class ConcatProcessor : public AgnosticNodeProcessor { int axis_node_pos_; private: - NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node, - const string& device) { - auto const_node = AddNodeConstScalar( - strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1); - // This is to ensure the concat node and the const node are - // in the same frame. - *const_node->add_input() = AsControlDependency(depended_node); - return const_node; - } - - string GetOrAddNodeConcatConst() { - string const_name; - if (is_in_frame_) { - int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0; - auto const_node = AddNodeConcatConst( - node_->name(), NodeName(node_->input(value_node_pos)), - node_->device()); - const_name = const_node->name(); - } else { - const_name = kConcatConst; - } - return const_name; + NodeDef* AddNodeConcatConst() { + auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_)); + // We created a copy of the node, so that we don't modify the original node, + // which might be used elsewhere. Note that this copy also copies the + // control dependency input in the case this node is inside a loop, + // to ensure added_node is in the same frame with node_. + auto added_node = graph_->add_node(); + *added_node = *axis_node; + added_node->set_name(strings::StrCat(kConcatConst, "-", node_->name())); + added_node->mutable_attr()->at({"value"}).mutable_tensor()->set_int_val(0, + 1); + return added_node; } }; @@ -1036,6 +1054,16 @@ class SplitProcessor : public AgnosticNodeProcessor { return input_pos; } + std::set<int> GetOutputPos() const override { + std::set<int> output_pos{0}; + if (HasAttribute(*node_, "num_split").ok()) { + for (int i = 1; i < node_->attr().at("num_split").i(); i++) { + output_pos.insert(i); + } + } + return output_pos; + } + Status CustomizedProcessing() override { string split_const_name = AddNodeSplitConst()->name(); node_map_->AddOutput(split_const_name, node_->name()); @@ -1073,7 +1101,7 @@ class SplitProcessor : public AgnosticNodeProcessor { // We created a copy of the node, so that we don't modify the original node, // which might be used elsewhere. Note that this copy also copies the // control dependency input in the case this node is inside a loop, - // to ensure added_node is in the same frame with the Split node. + // to ensure added_node is in the same frame with node_. NodeDef* added_node = graph_->add_node(); *added_node = *dim_node; added_node->set_name(strings::StrCat(kSplitConst, "-", node_->name())); @@ -1329,20 +1357,21 @@ class SumProcessor : public AgnosticNodeProcessor { Status AddLayoutTransposeToOutputs() override { return Status::OK(); } - Status CustomizedProcessing() override { - node_map_->AddOutput(kReductionConst, node_->name()); - *node_->mutable_input(1) = GetOrAddNodeReductionConst(); - return Status::OK(); - } + Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); } private: bool IsAlongDimNHW() const { - NodeDef* node = node_map_->GetNode(node_->input(1)); + NodeDef* reduction_indices = node_map_->GetNode(node_->input(1)); + if (!IsConstant(*reduction_indices)) { + return false; + } Tensor tensor; - if (node->attr().find({"value"}) == node->attr().end()) { + if (reduction_indices->attr().find({"value"}) == + reduction_indices->attr().end()) { return false; } - auto success = tensor.FromProto(node->attr().at({"value"}).tensor()); + auto success = + tensor.FromProto(reduction_indices->attr().at({"value"}).tensor()); if (!success) { LOG(ERROR) << "Failed to parse TensorProto."; return false; @@ -1356,29 +1385,6 @@ class SumProcessor : public AgnosticNodeProcessor { } return false; } - - NodeDef* AddNodeReductionConst(const string& suffix, - const string& depended_node, - const string& device) { - auto const_node = GraphProcessor::AddNodeReductionConst( - strings::StrCat(kReductionConst, "-", suffix), device); - // This is to ensure the Sum node and the const node are in the - // same frame. - *const_node->add_input() = AsControlDependency(depended_node); - return const_node; - } - - string GetOrAddNodeReductionConst() { - string const_name; - if (is_in_frame_) { - auto const_node = AddNodeReductionConst( - node_->name(), NodeName(node_->input(0)), node_->device()); - const_name = const_node->name(); - } else { - const_name = kReductionConst; - } - return const_name; - } }; class DataLayoutOptimizer : GraphProcessor { @@ -1409,18 +1415,10 @@ class DataLayoutOptimizer : GraphProcessor { return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1}); } - NodeDef* AddNodeConcatConst() { - return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1); - } - NodeDef* AddNodeGatherAxisConst() { return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0); } - NodeDef* AddNodeReductionConst() { - return GraphProcessor::AddNodeReductionConst(kReductionConst, ""); - } - // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic. Status Expand() { int node_size_original = graph_->node_size(); @@ -1474,9 +1472,7 @@ class DataLayoutOptimizer : GraphProcessor { if (graph_->node_size() > node_size_original) { NodeDef* n = AddNodePermNHWCToNCHW(); n = AddNodePermNCHWToNHWC(); - n = AddNodeConcatConst(); n = AddNodeGatherAxisConst(); - n = AddNodeReductionConst(); std::set<string> ops_format_agnostic = GetOpsFormatAgnostic(); for (int i = 0; i < graph_->node_size(); i++) { if (ops_format_agnostic.find(graph_->node(i).op()) != @@ -1620,27 +1616,20 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, virtual_placer_.reset(new VirtualPlacer(cluster)); nodes_to_preserve_ = item.NodesToPreserve(); GraphProperties graph_properties(item); - auto status = graph_properties.InferStatically(); + auto status = graph_properties.InferStatically(false); if (!status.ok()) { *output = item.graph; return status; } TuningConfig config; - config.no_gemm = false; + config.no_gemm = true; + // TODO(yaozhang): Enable tuning with various TuningConfig choices wtih + // the measurement-based estimator. status = Tune(item, graph_properties, config, output); - // This is based on an empirical observation that if the introduced Transpose - // nodes is more than 30, not using GEMM implementation would result in better - // performance. - if (status.ok() && GetNumTranspose(*output) > 30) { - config.no_gemm = true; - status = Tune(item, graph_properties, config, output); - } - if (!status.ok()) { *output = item.graph; } - return status; } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 8c89f6744b..363b4c3fd8 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -298,6 +298,39 @@ TEST_F(LayoutOptimizerTest, Connectivity) { EXPECT_EQ(node_i2_output->input(0), "i1"); } +TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto i1 = ops::Identity(s.WithOpName("i1"), conv); + auto i2 = ops::Identity(s.WithOpName("i2"), i1); + auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {}); + auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2); + auto i3 = ops::Identity(s.WithOpName("i3"), sub); + auto i4 = ops::Identity(s.WithOpName("i4"), i3); + auto i5 = ops::Identity(s.WithOpName("i5"), i4); + auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {}); + auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5); + auto i6 = ops::Identity(s.WithOpName("i6"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + // Make the graph not in topological order to test the handling of multi-hop + // connectivity (here we say two nodes are connected if all nodes in the + // middle are layout agnostic). If the graph is already in topological order, + // the problem is easier, where layout optimizer only needs to check + // single-hop connectivity. + NodeMap node_map_original(&item.graph); + auto node_i1 = node_map_original.GetNode("i1"); + auto node_mul = node_map_original.GetNode("mul"); + node_mul->Swap(node_i1); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map_output(&output); + auto mul_node = node_map_output.GetNode("mul"); + EXPECT_EQ(mul_node->input(0), "scalar_mul"); + EXPECT_EQ(mul_node->input(1), "i5"); +} + TEST_F(LayoutOptimizerTest, PreserveFetch) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto conv = SimpleConv2D(&s, 3, 2, "VALID"); @@ -495,7 +528,175 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) { auto split_node = node_map.GetNode("split"); EXPECT_EQ(split_node->input(0), "i1"); EXPECT_EQ(split_node->input(1), - "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split"); + "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split-1"); +} + +TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto axis = ops::Const(s.WithOpName("axis"), 3); + auto split = ops::Split(s.WithOpName("split"), axis, conv, 2); + auto concat = + ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis); + auto o = ops::Identity(s.WithOpName("o"), concat); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto concat_node = node_map.GetNode("concat"); + EXPECT_EQ(concat_node->input(0), "split:1"); + EXPECT_EQ(concat_node->input(1), "split:1"); + EXPECT_EQ(concat_node->input(2), "split:1"); + EXPECT_EQ(concat_node->input(3), "LayoutOptimizerConcatConst-concat"); + auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat"); + EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1); +} + +TEST_F(LayoutOptimizerTest, Concat) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto axis = ops::Const(s.WithOpName("axis"), 3); + auto split = ops::Split(s.WithOpName("split"), axis, conv, 2); + auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis); + auto o = ops::Identity(s.WithOpName("o"), concat); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto concat_node = node_map.GetNode("concat"); + EXPECT_EQ(concat_node->input(0), "split"); + EXPECT_EQ(concat_node->input(1), "split:1"); + EXPECT_EQ(concat_node->input(2), "LayoutOptimizerConcatConst-concat"); + auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat"); + EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1); +} + +TEST_F(LayoutOptimizerTest, Sum) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto reduction_indices = + ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3}); + auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices); + auto o = ops::Identity(s.WithOpName("o"), sum); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled + // because of the worse performance in some cases. + /* + NodeMap node_map(&output); + auto sum_node = node_map.GetNode("sum"); + EXPECT_EQ(sum_node->input(0), "Conv2D"); + EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices"); + auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices"); + Tensor tensor; + EXPECT_TRUE( + tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor())); + Tensor tensor_expected(DT_INT32, {3}); + test::FillValues<int>(&tensor_expected, {0, 2, 3}); + test::ExpectTensorEqual<int>(tensor_expected, tensor); + */ +} + +TEST_F(LayoutOptimizerTest, MulScalarAnd4D) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {}); + auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv); + auto o = ops::Identity(s.WithOpName("o"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto mul_node = node_map.GetNode("mul"); + EXPECT_EQ(mul_node->input(0), "scalar"); + EXPECT_EQ(mul_node->input(1), "Conv2D"); +} + +TEST_F(LayoutOptimizerTest, Mul4DAndScalar) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {}); + auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar); + auto o = ops::Identity(s.WithOpName("o"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto mul_node = node_map.GetNode("mul"); + EXPECT_EQ(mul_node->input(0), "Conv2D"); + EXPECT_EQ(mul_node->input(1), "scalar"); +} + +TEST_F(LayoutOptimizerTest, Mul4DAnd4D) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto i = ops::Identity(s.WithOpName("i"), conv); + auto mul = ops::Mul(s.WithOpName("mul"), conv, i); + auto o = ops::Identity(s.WithOpName("o"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto mul_node = node_map.GetNode("mul"); + EXPECT_EQ(mul_node->input(0), "Conv2D"); + EXPECT_EQ(mul_node->input(1), "i"); +} + +TEST_F(LayoutOptimizerTest, Mul4DAndVector) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2}); + auto mul = ops::Mul(s.WithOpName("mul"), conv, vector); + auto o = ops::Identity(s.WithOpName("o"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto mul_node = node_map.GetNode("mul"); + EXPECT_EQ(mul_node->input(0), "Conv2D"); + EXPECT_EQ(mul_node->input(1), "LayoutOptimizerReshapeNHWCToNCHW-mul-vector"); + auto mul_const = node_map.GetNode("LayoutOptimizerReshapeConst-mul-vector"); + Tensor tensor; + EXPECT_TRUE( + tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor())); + Tensor tensor_expected(DT_INT32, {4}); + test::FillValues<int>(&tensor_expected, {1, 2, 1, 1}); + test::ExpectTensorEqual<int>(tensor_expected, tensor); +} + +TEST_F(LayoutOptimizerTest, MulVectorAnd4D) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2}); + auto mul = ops::Mul(s.WithOpName("mul"), vector, conv); + auto o = ops::Identity(s.WithOpName("o"), mul); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto mul_node = node_map.GetNode("mul"); + // TODO(yaozhang): Support vector as the first input and 4d tensor as the + // second input for BinaryOpProcessor. + EXPECT_EQ(mul_node->input(0), "vector"); + EXPECT_EQ(mul_node->input(1), + "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-mul-1"); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 7c44ce15c6..a2a2680c4f 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -716,7 +716,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, { // Estimate the size of the data to swap for each node. GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); for (auto& swap : nodes_to_swap) { const NodeDef* node = swap.first; std::vector<OpInfo::TensorProperties> props = diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc index 6ce6deef2c..450e853407 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule.cc @@ -86,7 +86,7 @@ Status EstimateEarliestExecutionTimes( name_map.clear(); GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; VirtualPlacer placer(cluster); @@ -154,7 +154,7 @@ Status EstimateRequiredTimes( } } GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; VirtualPlacer placer(cluster); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 21411097e8..dcffb28513 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3923,7 +3923,11 @@ tf_kernel_library( "scatter_nd_op.h", "scatter_nd_op_gpu.cu.cc", ], - deps = STATE_DEPS + [":dense_update_functor"], + deps = STATE_DEPS + [ + ":dense_update_functor", + ":training_op_helpers", + ":variable_ops", + ], ) tf_kernel_library( @@ -5833,11 +5837,11 @@ cc_library( srcs = ["dataset.cc"], hdrs = ["dataset.h"], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/util/tensor_bundle", ], ) @@ -6125,6 +6129,18 @@ tf_kernel_library( ) tf_kernel_library( + name = "random_dataset_op", + srcs = ["random_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( name = "range_dataset_op", srcs = ["range_dataset_op.cc"], deps = [ @@ -6291,6 +6307,7 @@ tf_kernel_library( ":parallel_interleave_dataset_op", ":parallel_map_dataset_op", ":prefetch_dataset_op", + ":random_dataset_op", ":range_dataset_op", ":reader_dataset_ops", ":repeat_dataset_op", diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 3d2bb57aff..1791c51096 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -194,7 +194,23 @@ class Conv2DFastBackpropFilterOp : public OpKernel { context, (strides_[0] == 1 && strides_[3] == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + // TODO(yangzihao): Add a CPU implementation for dilated convolution. + OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1), + errors::InvalidArgument( + "Current Eigen and libxsmm implementations do not " + "yet support dilation rates larger than 1.")); } void Compute(OpKernelContext* context) override { @@ -262,6 +278,7 @@ class Conv2DFastBackpropFilterOp : public OpKernel { } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_; @@ -290,7 +307,23 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { context, (strides_[0] == 1 && strides_[3] == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + // TODO(yangzihao): Add a CPU implementation for dilated convolution. + OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1), + errors::InvalidArgument( + "Current libxsmm and customized CPU implementations do " + "not yet support dilation rates larger than 1.")); } void Compute(OpKernelContext* context) override { @@ -459,6 +492,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_; @@ -510,10 +544,30 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); int stride_n = GetTensorDim(strides_, data_format_, 'N'); int stride_c = GetTensorDim(strides_, data_format_, 'C'); + int stride_h = GetTensorDim(strides_, data_format_, 'H'); + int stride_w = GetTensorDim(strides_, data_format_, 'W'); OP_REQUIRES( context, (stride_n == 1 && stride_c == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); @@ -546,13 +600,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // do not support striding on the batch or depth dimension). const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); + const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input, - stride_rows, stride_cols, padding_, filter_backprop, - data_format_); + dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, + filter_backprop, data_format_); } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; bool use_cudnn_; @@ -566,38 +623,46 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { template <typename T> void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& out_backprop, const Tensor& input, int row_stride, - int col_stride, const Padding& padding, Tensor* filter_backprop, - TensorFormat data_format) { + const Tensor& out_backprop, const Tensor& input, int row_dilation, + int col_dilation, int row_stride, int col_stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; using perftools::gputools::dnn::AlgorithmDesc; using perftools::gputools::dnn::ProfileResult; + std::vector<int32> dilations(4, 1); + dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation; + dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation; + std::vector<int32> strides(4, 1); strides[GetTensorDimIndex(data_format, 'H')] = row_stride; strides[GetTensorDimIndex(data_format, 'W')] = col_stride; TensorShape filter_shape = filter_backprop->shape(); ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2( "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2, input.shape(), filter_shape, out_backprop.shape(), - strides, padding, data_format, &dims)); + dilations, strides, padding, data_format, &dims)); + // TODO(yangzihao): The padding computations should be done in + // GetWindowedOutputSize() functions. const int padding_rows = (padding == VALID) ? 0 : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); + (dims.spatial_dims[0].filter_size - 1) * + dims.spatial_dims[0].dilation + + 1 - dims.spatial_dims[0].input_size); const int padding_cols = (padding == VALID) ? 0 : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); + (dims.spatial_dims[1].filter_size - 1) * + dims.spatial_dims[1].dilation + + 1 - dims.spatial_dims[1].input_size); // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -730,7 +795,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( .set_input_feature_map_count(dims.in_depth) .set_output_feature_map_count(dims.out_depth); perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation) + .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation) + .set_vertical_filter_stride(dims.spatial_dims[0].stride) .set_horizontal_filter_stride(dims.spatial_dims[1].stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); @@ -821,6 +888,8 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( dims.out_depth, // out_depths {{dims.spatial_dims[0].filter_size, // filter_rows dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].dilation, // dilation_rows + dims.spatial_dims[1].dilation}}, // dilation_cols {{dims.spatial_dims[0].stride, // stride_rows dims.spatial_dims[1].stride}}, // stride_cols {{padding_rows, // padding_rows diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index d28f6b4d10..736241a029 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -198,7 +198,23 @@ class Conv2DFastBackpropInputOp : public OpKernel { context, (strides_[0] == 1 && strides_[3] == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + OP_REQUIRES(context, (dilations_[0] && dilations_[3]), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + // TODO(yangzihao): Add a CPU implementation for dilated convolution. + OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1), + errors::InvalidArgument( + "Current Eigen and libxsmm implementations do not " + "yet support dilation rates larger than 1.")); } void Compute(OpKernelContext* context) override { @@ -268,6 +284,7 @@ class Conv2DFastBackpropInputOp : public OpKernel { } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_; @@ -296,7 +313,23 @@ class Conv2DCustomBackpropInputOp : public OpKernel { context, (strides_[0] == 1 && strides_[3] == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + // TODO(yangzihao): Add a CPU implementation for dilated convolution. + OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1), + errors::InvalidArgument( + "Current libxsmm and customized CPU implementations do " + "not yet support dilation rates larger than 1.")); } void Compute(OpKernelContext* context) override { @@ -532,6 +565,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel { } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; TensorFormat data_format_; @@ -586,10 +620,30 @@ class Conv2DSlowBackpropInputOp : public OpKernel { "specify 4 dimensions")); int stride_n = GetTensorDim(strides_, data_format_, 'N'); int stride_c = GetTensorDim(strides_, data_format_, 'C'); + int stride_h = GetTensorDim(strides_, data_format_, 'H'); + int stride_w = GetTensorDim(strides_, data_format_, 'W'); OP_REQUIRES( context, (stride_n == 1 && stride_c == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); @@ -622,12 +676,16 @@ class Conv2DSlowBackpropInputOp : public OpKernel { // do not support striding on the batch or depth dimension). const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); + const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter, - stride_rows, stride_cols, padding_, in_backprop, data_format_); + dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, + in_backprop, data_format_); } private: + std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; bool use_cudnn_; @@ -641,39 +699,48 @@ class Conv2DSlowBackpropInputOp : public OpKernel { template <typename T> void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& out_backprop, const Tensor& filter, int row_stride, - int col_stride, const Padding& padding, Tensor* in_backprop, - TensorFormat data_format) { + const Tensor& out_backprop, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, const Padding& padding, + Tensor* in_backprop, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; using perftools::gputools::dnn::AlgorithmDesc; using perftools::gputools::dnn::ProfileResult; std::vector<int32> strides(4, 1); - strides[GetTensorDimIndex(data_format, 'H')] = row_stride; - strides[GetTensorDimIndex(data_format, 'W')] = col_stride; + std::vector<int32> dilations(4, 1); + auto input_h = GetTensorDimIndex(data_format, 'H'); + auto input_w = GetTensorDimIndex(data_format, 'W'); + strides[input_h] = row_stride; + strides[input_w] = col_stride; + dilations[input_h] = row_dilation; + dilations[input_w] = col_dilation; TensorShape input_shape = in_backprop->shape(); const TensorShape& filter_shape = filter.shape(); ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2( "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape, filter_shape, out_backprop.shape(), - strides, padding, data_format, &dims)); + dilations, strides, padding, data_format, &dims)); + // TODO(yangzihao): The padding computations should be done in + // GetWindowedOutputSize() functions. const int padding_rows = (padding == VALID) ? 0 : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); + (dims.spatial_dims[0].filter_size - 1) * + dims.spatial_dims[0].dilation + + 1 - dims.spatial_dims[0].input_size); const int padding_cols = (padding == VALID) ? 0 : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); + (dims.spatial_dims[1].filter_size - 1) * + dims.spatial_dims[1].dilation + + 1 - dims.spatial_dims[1].input_size); // TODO(keveman): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -789,7 +856,9 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( .set_input_feature_map_count(dims.in_depth) .set_output_feature_map_count(dims.out_depth); perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation) + .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation) + .set_vertical_filter_stride(dims.spatial_dims[0].stride) .set_horizontal_filter_stride(dims.spatial_dims[1].stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); @@ -875,6 +944,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( dims.out_depth, // out_depths {{dims.spatial_dims[0].filter_size, // filter_rows dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].dilation, // dilation_rows + dims.spatial_dims[1].dilation}}, // dilation_cols {{dims.spatial_dims[0].stride, // stride_rows dims.spatial_dims[1].stride}}, // stride_cols {{padding_rows, // padding_rows diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index e068fb8684..535586d53a 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -175,15 +175,17 @@ template <typename Device, typename T> struct LaunchConv2DBackpropInputOp { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Tensor& out_backprop, const Tensor& filter, - int row_stride, int col_stride, const Padding& padding, - Tensor* in_backprop, TensorFormat data_format); + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, Tensor* in_backprop, + TensorFormat data_format); }; template <typename Device, typename T> struct LaunchConv2DBackpropFilterOp { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Tensor& out_backprop, const Tensor& input, - int row_stride, int col_stride, const Padding& padding, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, Tensor* filter_backprop, TensorFormat data_format); }; @@ -191,8 +193,9 @@ struct LaunchConv2DBackpropFilterOp { template <typename T> struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Padding& padding, Tensor* output, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, Tensor* output, TensorFormat data_format); }; @@ -200,7 +203,8 @@ template <typename T> struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Tensor& out_backprop, const Tensor& input, - int row_stride, int col_stride, const Padding& padding, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, Tensor* filter_backprop, TensorFormat data_format); }; #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index c2d24d1f12..4d0f1ab317 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -645,6 +645,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { {{input_size[0], input_size[1], input_size[2]}}, out_depth, {{filter_size[0], filter_size[1], filter_size[2]}}, + // TODO(yangzihao): Send in arbitrary dilation rates after the dilated + // conv is supported. + /*dilations=*/{{1, 1, 1}}, {{strides[0], strides[1], strides[2]}}, {{padding_planes, padding_rows, padding_cols}}, dtype, @@ -1011,6 +1014,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { {{input_size[0], input_size[1], input_size[2]}}, out_depth, {{filter_size[0], filter_size[1], filter_size[2]}}, + {{1, 1, 1}}, {{strides[0], strides[1], strides[2]}}, {{padding_planes, padding_rows, padding_cols}}, dtype, diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index bb67113fb0..ba40c428e4 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -112,7 +112,8 @@ struct LaunchGeneric { template <typename T> struct LaunchConv2DOp<CPUDevice, T> { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, + const Tensor& input, const Tensor& filter, + int /*row_dilation*/, int /*col_dilation*/, int row_stride, int col_stride, const Padding& padding, Tensor* output, TensorFormat data_format) { if (data_format != FORMAT_NHWC) { @@ -133,8 +134,10 @@ class LaunchDeepConvOp { const Tensor& filter, int batch, int input_rows, int input_cols, int in_depth, int filter_rows, int filter_cols, int pad_rows, int pad_cols, int out_rows, - int out_cols, int out_depth, int stride_rows, int stride_cols, - Tensor* output, TensorFormat data_format) { + int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/, + int /*dilation_cols*/, int /*stride_rows*/, + int /*stride_cols*/, Tensor* /*output*/, + TensorFormat /*data_format*/) { return false; } }; @@ -147,9 +150,11 @@ class LaunchDeepConvOp<CPUDevice, float> { const Tensor& filter, int batch, int input_rows, int input_cols, int in_depth, int filter_rows, int filter_cols, int pad_rows, int pad_cols, int out_rows, - int out_cols, int out_depth, int stride_rows, int stride_cols, + int out_cols, int out_depth, int dilation_rows, + int dilation_cols, int stride_rows, int stride_cols, Tensor* output, TensorFormat data_format) { - if (data_format != FORMAT_NHWC || + if (data_format != FORMAT_NHWC || dilation_rows != 1 || + dilation_cols != 1 || !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols)) { return false; @@ -187,7 +192,8 @@ class LaunchXsmmConvOp { int input_cols, int in_depth, int filter_rows, int filter_cols, int pad_rows, int pad_cols, int out_rows, int out_cols, int out_depth, int stride_rows, int stride_cols, - Tensor* output, TensorFormat data_format) { + int dilation_rows, int dilation_cols, Tensor* output, + TensorFormat data_format) { return false; } }; @@ -199,7 +205,8 @@ class LaunchXsmmConvOp<CPUDevice, float> { const Tensor& filter, int batch, int input_rows, int input_cols, int in_depth, int filter_rows, int filter_cols, int pad_rows, int pad_cols, int out_rows, - int out_cols, int out_depth, int stride_rows, int stride_cols, + int out_cols, int out_depth, int dilation_rows, + int dilation_cols, int stride_rows, int stride_cols, Tensor* output, TensorFormat data_format) { auto num_threads = ctx->device()->tensorflow_cpu_worker_threads()->num_threads; @@ -228,11 +235,8 @@ class LaunchXsmmConvOp<CPUDevice, float> { desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE; desc.datatype = LIBXSMM_DNN_DATATYPE_F32; - if (!CanUseXsmmConv2D(desc, data_format)) { - return false; - } - - if (!CanUseXsmmConv2D(desc, data_format)) { + if (dilation_rows != 1 || dilation_cols != 1 || + !CanUseXsmmConv2D(desc, data_format)) { return false; } @@ -251,6 +255,7 @@ template <typename Device, typename T> class Conv2DOp : public BinaryOp<T> { public: explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) { + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); @@ -259,15 +264,35 @@ class Conv2DOp : public BinaryOp<T> { OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); OP_REQUIRES(context, strides_.size() == 4, errors::InvalidArgument("Sliding window strides field must " "specify 4 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); + const int64 stride_h = GetTensorDim(strides_, data_format_, 'H'); + const int64 stride_w = GetTensorDim(strides_, data_format_, 'W'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + + const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } @@ -334,18 +359,22 @@ class Conv2DOp : public BinaryOp<T> { errors::InvalidArgument("batch is too large")); const int batch = static_cast<int>(batch_raw); - // For now we take the stride from the second and third dimensions only (we - // do not support striding on the batch or depth dimension). + // For now we take the stride and dilation from the second and third + // dimensions only (we do not support striding or dilation on the batch or + // depth dimension). const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); + const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_cols)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( + input_rows, filter_rows, dilation_rows, + stride_rows, padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( + input_cols, filter_cols, dilation_cols, + stride_cols, padding_, &out_cols, &pad_cols)); TensorShape out_shape = ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); @@ -361,6 +390,8 @@ class Conv2DOp : public BinaryOp<T> { << ", filter_rows = " << filter_rows << ", stride_rows = " << stride_rows << ", stride_cols = " << stride_cols + << ", dilation_rows = " << dilation_rows + << ", dilation_cols = " << dilation_cols << ", out_depth = " << out_depth; // If there is nothing to compute, return. @@ -372,7 +403,8 @@ class Conv2DOp : public BinaryOp<T> { if (LaunchXsmmConvOp<Device, T>::Run( context, input, filter, batch, input_rows, input_cols, in_depth, filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, stride_rows, stride_cols, output, data_format_)) { + out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, + output, data_format_)) { return; } #endif @@ -380,15 +412,18 @@ class Conv2DOp : public BinaryOp<T> { if (LaunchDeepConvOp<Device, T>::Run( context, input, filter, batch, input_rows, input_cols, in_depth, filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, stride_rows, stride_cols, output, data_format_)) { + out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, + output, data_format_)) { return; } launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, - stride_rows, stride_cols, padding_, output, data_format_); + dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, + output, data_format_); } private: + std::vector<int32> dilations_; std::vector<int32> strides_; bool use_cudnn_; Padding padding_; @@ -443,9 +478,9 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters, template <typename T> void LaunchConv2DOp<GPUDevice, T>::operator()( OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input_param, const Tensor& filter, int row_stride, - int col_stride, const Padding& padding, Tensor* output, - TensorFormat data_format) { + const Tensor& input_param, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, const Padding& padding, + Tensor* output, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; using perftools::gputools::dnn::AlgorithmDesc; using perftools::gputools::dnn::ProfileResult; @@ -461,8 +496,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( Tensor input = input_param; - if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && - col_stride == 1 && data_format == FORMAT_NHWC) { + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 && + col_dilation == 1 && row_stride == 1 && col_stride == 1 && + data_format == FORMAT_NHWC) { // 1x1 filter, so call cublas directly. const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2); const uint64 k = filter.dim_size(2); @@ -487,7 +523,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( } return; } else if (filter.dim_size(0) == input.dim_size(1) && - filter.dim_size(1) == input.dim_size(2) && padding == VALID && + filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 && + col_dilation == 1 && padding == VALID && data_format == FORMAT_NHWC) { // The input data and filter have the same height/width, so call cublas // directly. @@ -530,17 +567,19 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( const int64 patch_cols = filter.dim_size(1); if (padding == SAME) { // Total padding on rows and cols is - // Pr = (R' - 1) * S + Kr - R - // Pc = (C' - 1) * S + Kc - C + // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R + // Pc = (C' - 1) * S + (Kc - 1) * Dc + 1 - C // where (R', C') are output dimensions, (R, C) are input dimensions, S - // is stride, (Kr, Kc) are filter dimensions. + // is stride, (Dr, Dc) are dilations, (Kr, Kc) are filter dimensions. // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means // we pad more on the right and bottom than on the top and left. padding_rows = - std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows); + std::max<int>(0, (out_rows - 1) * row_stride + + (patch_rows - 1) * row_dilation + 1 - in_rows); padding_cols = - std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols); + std::max<int>(0, (out_cols - 1) * col_stride + + (patch_cols - 1) * col_dilation + 1 - in_cols); const bool rows_odd = (padding_rows % 2 != 0); const bool cols_odd = (padding_cols % 2 != 0); if (rows_odd || cols_odd) { @@ -605,7 +644,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( .set_input_feature_map_count(filter.dim_size(2)) .set_output_feature_map_count(filter.dim_size(3)); perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(row_stride) + conv_desc.set_vertical_dilation_rate(row_dilation) + .set_horizontal_dilation_rate(col_dilation) + .set_vertical_filter_stride(row_stride) .set_horizontal_filter_stride(col_stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); @@ -652,6 +693,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( out_depths, // out_depths {{patch_rows, // filter_rows patch_cols}}, // filter_cols + {{row_dilation, // dilation_rows + col_dilation}}, // dilation_cols {{row_stride, // stride_rows col_stride}}, // stride_cols {{padding_rows, // padding_rows diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index e29271dff2..09a3b78776 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -34,8 +34,9 @@ class OpKernelContext; template <typename Device, typename T> struct LaunchConv2DOp { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Padding& padding, Tensor* output, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, Tensor* output, TensorFormat data_format); }; @@ -43,8 +44,9 @@ struct LaunchConv2DOp { template <typename T> struct LaunchConv2DOp<Eigen::GpuDevice, T> { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Padding& padding, Tensor* output, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, Tensor* output, TensorFormat data_format); }; #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index 37cb67bc51..39202d7334 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -377,6 +377,9 @@ struct LaunchConvOp<GPUDevice, T> { {{in_planes, in_rows, in_cols}}, out_depth, {{filter_planes, filter_rows, filter_cols}}, + // TODO(yangzihao): Send in arbitrary dilation rates after the dilated + // conv is supported. + /*dilations=*/{{1, 1, 1}}, {{strides[0], strides[1], strides[2]}}, {{pad_planes, pad_rows, pad_cols}}, dtype, diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index c852dc9991..6f82698596 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -91,13 +91,14 @@ class ConvParameters { using SpatialArray = gtl::InlinedVector<int64, 3>; ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id) + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, int device_id) : batch_(batch), in_depths_(in_depths), out_depths_(out_depths), in_(in), filter_(filter), + dilation_(dilation), stride_(stride), padding_(padding), dtype_(dtype), @@ -107,6 +108,7 @@ class ConvParameters { for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val); hash_code_ = Hash64Combine(hash_code_, out_depths); for (int64 val : filter) hash_code_ = Hash64Combine(hash_code_, val); + for (int64 val : dilation) hash_code_ = Hash64Combine(hash_code_, val); for (int64 val : stride) hash_code_ = Hash64Combine(hash_code_, val); for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val); hash_code_ = Hash64Combine(hash_code_, dtype); @@ -128,6 +130,7 @@ class ConvParameters { "(", str_util::Join(in_, ", "), "), ", out_depths_, ", ", "(", str_util::Join(filter_, ", "), "), ", + "(", str_util::Join(dilation_, ", "), "), ", "(", str_util::Join(stride_, ", "), "), ", "(", str_util::Join(padding_, ", "), "), ", dtype_, ", ", @@ -154,11 +157,11 @@ class ConvParameters { protected: using ParameterDataType = std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray, - SpatialArray, DataType, int>; + SpatialArray, SpatialArray, DataType, int>; ParameterDataType get_data_as_tuple() const { return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_, - stride_, padding_, dtype_, device_id_); + dilation_, stride_, padding_, dtype_, device_id_); } uint64 hash_code_; @@ -169,6 +172,7 @@ class ConvParameters { int64 out_depths_; SpatialArray in_; SpatialArray filter_; + SpatialArray dilation_; SpatialArray stride_; SpatialArray padding_; DataType dtype_; diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index ea54d6cf6c..666bca265c 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -43,6 +43,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) { 128, // out_depths {{3, // filter_rows 3}}, // filter_cols + {{1, // dilation_rows + 1}}, // dilation_cols {{1, // stride_rows 1}}, // stride_cols {{0, // padding_rows @@ -60,6 +62,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) { 768, // out_depths {{3, // filter_rows 3}}, // filter_cols + {{1, // dilation_rows + 1}}, // dilation_cols {{1, // stride_rows 1}}, // stride_cols {{0, // padding_rows diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc index e6e1b83b30..0aec6aac34 100644 --- a/tensorflow/core/kernels/cwise_op_asinh.cc +++ b/tensorflow/core/kernels/cwise_op_asinh.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc index fcfa2956f7..0972129787 100644 --- a/tensorflow/core/kernels/dataset.cc +++ b/tensorflow/core/kernels/dataset.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/kernels/dataset.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/node_builder.h" + namespace tensorflow { namespace { @@ -70,6 +73,143 @@ class DatasetVariantWrapper { } // namespace +Status GraphDefBuilderWrapper::AddDataset( + const GraphDatasetBase* dataset, + const std::vector<std::pair<size_t, Node*>>& inputs, + const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, + const std::vector<std::pair<StringPiece, AttrValue>>& attrs, + Node** output) { + const string& op_type_name = dataset->op_name(); + std::unique_ptr<const GraphDefBuilder::Options> opts( + new GraphDefBuilder::Options(b_->opts())); + // TODO(srbs|mrry): Not all datasets have output_types and output_shapes + // attributes defined. It will be nice to have a consistent pattern. + bool has_output_types_attr = HasAttr(op_type_name, "output_types"); + bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); + if (has_output_shapes_attr) { + opts.reset(new GraphDefBuilder::Options( + opts->WithAttr("output_shapes", dataset->output_shapes()))); + } + if (has_output_types_attr) { + opts.reset(new GraphDefBuilder::Options( + opts->WithAttr("output_types", dataset->output_dtypes()))); + } + for (auto attr : attrs) { + opts.reset( + new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second))); + } + if (opts->HaveError()) { + return errors::Internal("AddDataset: Failed to build Options with error ", + opts->StatusToString()); + } + NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, + opts->op_registry()); + { + size_t total_size = inputs.size() + list_inputs.size(); + auto inputs_iter = inputs.begin(); + auto list_inputs_iter = list_inputs.begin(); + for (int i = 0; i < total_size; i++) { + if (inputs_iter != inputs.end() && inputs_iter->first == i) { + node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second)); + inputs_iter++; + } else if (list_inputs_iter != list_inputs.end() && + list_inputs_iter->first == i) { + std::vector<NodeBuilder::NodeOut> nodeout_inputs; + nodeout_inputs.reserve(list_inputs_iter->second.size()); + for (Node* n : list_inputs_iter->second) { + nodeout_inputs.emplace_back(n); + } + node_builder.Input(nodeout_inputs); + list_inputs_iter++; + } else { + return errors::InvalidArgument("No input found for index ", i); + } + } + } + *output = opts->FinalizeBuilder(&node_builder); + if (*output == nullptr) { + return errors::Internal("AddDataset: Failed to build ", op_type_name, + " op with error ", opts->StatusToString()); + } + return Status::OK(); +} + +Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx, + const string& function_name) { + if (b_->HasFunction(function_name)) { + LOG(INFO) << "Function with name " << function_name << "already exists in" + << " the graph. It will not be added again."; + return Status::OK(); + } + TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); + const FunctionLibraryDefinition* flib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* f_def = flib_def->Find(function_name); + if (f_def == nullptr) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in the registry."); + } + FunctionDefLibrary def; + *def.add_function() = *f_def; + const string gradient_func = flib_def->FindGradient(function_name); + if (!gradient_func.empty()) { + GradientDef* g_def = def.add_gradient(); + g_def->set_function_name(function_name); + g_def->set_gradient_func(gradient_func); + } + TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); + + // Recursively add functions in inputs of function_name. + for (const NodeDef& node_def : f_def->node_def()) { + const OpRegistrationData* op_reg_data = nullptr; + TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); + if (op_reg_data->is_function_op) { + TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); + } + // Recursively add functions in attrs of this NodeDef. + for (const auto& pair : node_def.attr()) { + TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx)); + } + } + + // Recursively add functions in attrs of function_name. + for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { + TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx)); + } + return Status::OK(); +} + +void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, + Node** output) { + *output = ops::SourceOp( + "Const", + b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); +} + +bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name, + const string& attr_name) const { + const OpDef* op_def = nullptr; + Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); + if (!s.ok() || op_def == nullptr) { + return false; + } + return HasAttr(op_def, attr_name); +} + +Status GraphDatasetBase::Serialize(OpKernelContext* ctx, + string* serialized_graph_def, + string* output_node) const { + GraphDefBuilder b; + DatasetGraphDefBuilder db(&b); + Node* node = nullptr; + TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); + *output_node = node->name(); + GraphDef graph_def; + TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + graph_def.SerializeToString(serialized_graph_def); + return Status::OK(); +} + Status GetDatasetFromVariantTensor(const Tensor& tensor, DatasetBase** out_dataset) { if (!(tensor.dtype() == DT_VARIANT || diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index afbebb0692..504a88a309 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -19,12 +19,13 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" -#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tracing.h" @@ -59,6 +60,12 @@ class IteratorStateWriter { virtual ~IteratorStateWriter() {} }; +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class GraphDatasetBase; +class Node; + // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. class GraphDefBuilderWrapper { public: @@ -110,10 +117,8 @@ class GraphDefBuilderWrapper { return Status::OK(); } - template <class DatasetType> - Status AddDataset(const DatasetType* dataset, - const std::vector<NodeBuilder::NodeOut>& inputs, - Node** output) { + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector<Node*>& inputs, Node** output) { return AddDataset(dataset, inputs, {}, output); } @@ -125,77 +130,23 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - template <class DatasetType> - Status AddDataset(const DatasetType* dataset, - const std::vector<NodeBuilder::NodeOut>& inputs, + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector<Node*>& inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, Node** output) { - std::vector<std::pair<size_t, NodeBuilder::NodeOut>> enumerated_inputs( - inputs.size()); + std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size()); for (int i = 0; i < inputs.size(); i++) { enumerated_inputs[i] = std::make_pair(i, inputs[i]); } return AddDataset(dataset, enumerated_inputs, {}, attrs, output); } - template <class DatasetType> Status AddDataset( - const DatasetType* dataset, - const std::vector<std::pair<size_t, NodeBuilder::NodeOut>>& inputs, - const std::vector< - std::pair<size_t, gtl::ArraySlice<NodeBuilder::NodeOut>>>& - list_inputs, + const GraphDatasetBase* dataset, + const std::vector<std::pair<size_t, Node*>>& inputs, + const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, - Node** output) { - const string& op_type_name = dataset->op_name(); - std::unique_ptr<const GraphDefBuilder::Options> opts( - new GraphDefBuilder::Options(b_->opts())); - // TODO(srbs|mrry): Not all datasets have output_types and output_shapes - // attributes defined. It will be nice to have a consistent pattern. - bool has_output_types_attr = HasAttr(op_type_name, "output_types"); - bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); - if (has_output_shapes_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_shapes", dataset->output_shapes()))); - } - if (has_output_types_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_types", dataset->output_dtypes()))); - } - for (auto attr : attrs) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr(attr.first, attr.second))); - } - if (opts->HaveError()) { - return errors::Internal("AddDataset: Failed to build Options with error ", - opts->StatusToString()); - } - NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, - opts->op_registry()); - { - size_t total_size = inputs.size() + list_inputs.size(); - auto inputs_iter = inputs.begin(); - auto list_inputs_iter = list_inputs.begin(); - for (int i = 0; i < total_size; i++) { - if (inputs_iter != inputs.end() && inputs_iter->first == i) { - node_builder.Input(inputs_iter->second); - inputs_iter++; - } else if (list_inputs_iter != list_inputs.end() && - list_inputs_iter->first == i) { - node_builder.Input(list_inputs_iter->second); - list_inputs_iter++; - } else { - return errors::InvalidArgument("No input found for index ", i); - } - } - } - *output = opts->FinalizeBuilder(&node_builder); - if (*output == nullptr) { - return errors::Internal("AddDataset: Failed to build ", op_type_name, - " op with error ", opts->StatusToString()); - } - return Status::OK(); - } + Node** output); // Adds a user-defined function with name `function_name` to the graph and // recursively adds all functions it references. If a function with a matching @@ -203,50 +154,7 @@ class GraphDefBuilderWrapper { // name `function_name` is not found in the FunctionLibraryDefinition, returns // an InvalidArgumentError. If the function with name `function_name` or any // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(OpKernelContext* ctx, const string& function_name) { - if (b_->HasFunction(function_name)) { - LOG(INFO) << "Function with name " << function_name << "already exists in" - << " the graph. It will not be added again."; - return Status::OK(); - } - TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); - const FunctionLibraryDefinition* flib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* f_def = flib_def->Find(function_name); - if (f_def == nullptr) { - return errors::InvalidArgument("Unable to find FunctionDef for ", - function_name, " in the registry."); - } - FunctionDefLibrary def; - *def.add_function() = *f_def; - const string gradient_func = flib_def->FindGradient(function_name); - if (!gradient_func.empty()) { - GradientDef* g_def = def.add_gradient(); - g_def->set_function_name(function_name); - g_def->set_gradient_func(gradient_func); - } - TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); - - // Recursively add functions in inputs of function_name. - for (const NodeDef& node_def : f_def->node_def()) { - const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); - if (op_reg_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); - } - // Recursively add functions in attrs of this NodeDef. - for (const auto& pair : node_def.attr()) { - TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx)); - } - } - - // Recursively add functions in attrs of function_name. - for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); - iter++) { - TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx)); - } - return Status::OK(); - } + Status AddFunction(OpKernelContext* ctx, const string& function_name); template <typename T> void BuildAttrValue(const T& value, AttrValue* attr) { @@ -254,11 +162,7 @@ class GraphDefBuilderWrapper { } private: - void AddTensorInternal(const Tensor& val, Node** output) { - *output = ops::SourceOp( - "Const", - b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); - } + void AddTensorInternal(const Tensor& val, Node** output); Status EnsureFunctionIsStateless(OpKernelContext* ctx, const string& function_name) const { @@ -294,14 +198,7 @@ class GraphDefBuilderWrapper { HasAttr(op_def, "output_shapes"); } - bool HasAttr(const string& op_type_name, const string& attr_name) const { - const OpDef* op_def = nullptr; - Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); - if (!s.ok() || op_def == nullptr) { - return false; - } - return HasAttr(op_def, attr_name); - } + bool HasAttr(const string& op_type_name, const string& attr_name) const; bool HasAttr(const OpDef* op_def, const string& attr_name) const { for (auto attr : op_def->attr()) { @@ -548,17 +445,7 @@ class GraphDatasetBase : public DatasetBase { private: Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, - string* output_node) const { - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); - *output_node = node->name(); - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - graph_def.SerializeToString(serialized_graph_def); - return Status::OK(); - } + string* output_node) const; const string op_name_; }; diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc index cd58c80912..bd20e20cad 100644 --- a/tensorflow/core/kernels/dataset_utils.cc +++ b/tensorflow/core/kernels/dataset_utils.cc @@ -32,7 +32,7 @@ Status MakeIteratorFromInputElement( // is always 0, so a negative random step ID should suffice. opts.step_id = CapturedFunction::generate_step_id(); ScopedStepContainer step_container( - opts.step_id, [captured_func, ctx](const string& name) { + opts.step_id, [captured_func](const string& name) { captured_func->resource_manager()->Cleanup(name).IgnoreError(); }); opts.step_container = &step_container; diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index 2759ecb2f1..a5fd07fbe1 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -373,8 +373,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> { // If in_depth==1, this operation is just a standard convolution, so // invoke that op. if (std::is_same<T, float>::value && in_depth == 1) { + // TODO(yangzihao): Send in arbitrary dilation rates after the dilated + // conv is supported. launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, - stride_, stride_, padding_, output, data_format_); + /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_, + padding_, output, data_format_); return; } diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc index e4d80e4ce3..67417d467d 100644 --- a/tensorflow/core/kernels/filter_dataset_op.cc +++ b/tensorflow/core/kernels/filter_dataset_op.cc @@ -95,7 +95,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { DataTypeVector other_arguments_types; other_arguments_types.reserve(captured_func_->captured_inputs().size()); - std::vector<NodeBuilder::NodeOut> other_arguments; + std::vector<Node*> other_arguments; other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; @@ -149,7 +149,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = CapturedFunction::generate_step_id(); ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { + opts.step_id, [this](const string& name) { dataset() ->captured_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc index ac1689e5bf..8fe8489371 100644 --- a/tensorflow/core/kernels/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/flat_map_dataset_op.cc @@ -102,7 +102,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { DataTypeVector other_arguments_types; other_arguments_types.reserve(captured_func_->captured_inputs().size()); - std::vector<NodeBuilder::NodeOut> other_arguments; + std::vector<Node*> other_arguments; other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc index 8644bcf9b5..604555a560 100644 --- a/tensorflow/core/kernels/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc @@ -169,7 +169,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { opts.step_id = CapturedFunction::generate_step_id(); opts.runner = ctx->runner(); ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { + opts.step_id, [this](const string& name) { dataset() ->captured_key_func_->resource_manager() ->Cleanup(name) @@ -198,7 +198,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { opts2.step_id = CapturedFunction::generate_step_id(); opts2.runner = ctx->runner(); ScopedStepContainer step_container2( - opts2.step_id, [this, ctx](const string& name) { + opts2.step_id, [this](const string& name) { dataset() ->captured_window_size_func_->resource_manager() ->Cleanup(name) @@ -257,7 +257,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { opts.step_id = CapturedFunction::generate_step_id(); opts.runner = ctx->runner(); ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { + opts.step_id, [this](const string& name) { dataset() ->captured_reduce_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc index cbee68b2db..833e8cb9c5 100644 --- a/tensorflow/core/kernels/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/interleave_dataset_op.cc @@ -126,7 +126,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); DataTypeVector other_arguments_types; other_arguments_types.reserve(captured_func_->captured_inputs().size()); - std::vector<NodeBuilder::NodeOut> other_arguments; + std::vector<Node*> other_arguments; other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; diff --git a/tensorflow/core/kernels/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/map_and_batch_dataset_op.cc index ad1e356dbd..9bd66e681f 100644 --- a/tensorflow/core/kernels/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/map_and_batch_dataset_op.cc @@ -239,8 +239,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // to unblock a consumer. FunctionLibraryRuntime::Options opts; opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer* step_container = new ScopedStepContainer( - opts.step_id, [this, ctx](const string& name) { + ScopedStepContainer* step_container = + new ScopedStepContainer(opts.step_id, [this](const string& name) { dataset() ->captured_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index 4ba09bc335..29899a987e 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -100,7 +100,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { DataTypeVector other_arguments_types( captured_func_->captured_inputs().size()); - std::vector<NodeBuilder::NodeOut> other_arguments( + std::vector<Node*> other_arguments( captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; @@ -146,7 +146,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = CapturedFunction::generate_step_id(); ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { + opts.step_id, [this](const string& name) { dataset() ->captured_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index 138acdf298..9fee94f946 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -28,6 +28,7 @@ limitations under the License. #if defined(INTEL_MKL) #include <vector> #include "mkl_cblas.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -72,10 +73,10 @@ class BatchMatMulMkl : public OpKernel { TensorShape out_shape; for (int i = 0; i < ndims - 2; ++i) { OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i), - errors::InvalidArgument("lhs.dim(", i, ") and rhs.dim(", i, - ") must be the same: ", - lhs.shape().DebugString(), " vs ", - rhs.shape().DebugString())); + errors::InvalidArgument( + "lhs.dim(", i, ") and rhs.dim(", i, + ") must be the same: ", lhs.shape().DebugString(), " vs ", + rhs.shape().DebugString())); out_shape.AddDim(lhs.dim_size(i)); } auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements(); @@ -109,7 +110,7 @@ class BatchMatMulMkl : public OpKernel { const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1); const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2); const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2); - + std::vector<MKL_INT> m_array(batch_size, M); std::vector<MKL_INT> n_array(batch_size, N); std::vector<MKL_INT> k_array(batch_size, K); @@ -128,7 +129,7 @@ class BatchMatMulMkl : public OpKernel { b_array.push_back(&rhs_reshaped(i, 0, 0)); c_array.push_back(&out_reshaped(i, 0, 0)); } - + MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0], &k_array[0], &a_array[0], &lda_array[0], &b_array[0], &ldb_array[0], &c_array[0], &ldc_array[0], 1, diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc index 8c0109f5c8..d086abb247 100644 --- a/tensorflow/core/kernels/multinomial_op.cc +++ b/tensorflow/core/kernels/multinomial_op.cc @@ -40,7 +40,7 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { -template <typename Device, typename T> +template <typename Device, typename T, typename OutputType> struct MultinomialFunctor { void operator()(OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix logits, @@ -49,11 +49,11 @@ struct MultinomialFunctor { typename TTypes<float>::Flat scratch, int batch_size, int num_classes, int num_samples, const random::PhiloxRandom& gen, - typename TTypes<int64>::Matrix output); + typename TTypes<OutputType>::Matrix output); }; -template <typename T> -struct MultinomialFunctor<CPUDevice, T> { +template <typename T, typename OutputType> +struct MultinomialFunctor<CPUDevice, T, OutputType> { void operator()(OpKernelContext* ctx, const CPUDevice& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<float>::Flat /* noises */, @@ -61,7 +61,7 @@ struct MultinomialFunctor<CPUDevice, T> { typename TTypes<float>::Flat /* scratch */, int batch_size, int num_classes, int num_samples, const random::PhiloxRandom& gen, - typename TTypes<int64>::Matrix output) { + typename TTypes<OutputType>::Matrix output) { auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); // The implementation only parallelizes by batch. @@ -128,7 +128,7 @@ struct MultinomialFunctor<CPUDevice, T> { } // namespace functor // Samples from a multinomial distribution. -template <typename Device, typename T> +template <typename Device, typename T, typename OutputType> class MultinomialOp : public OpKernel { public: explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) { @@ -195,11 +195,11 @@ class MultinomialOp : public OpKernel { if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2; auto rng = generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256); - functor::MultinomialFunctor<Device, T>()( + functor::MultinomialFunctor<Device, T, OutputType>()( ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(), noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(), batch_size, num_classes, num_samples, rng, - samples_t->matrix<int64>()); + samples_t->matrix<OutputType>()); } } @@ -209,10 +209,17 @@ class MultinomialOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp); }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Multinomial").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ - MultinomialOp<CPUDevice, TYPE>); +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + MultinomialOp<CPUDevice, TYPE, int32>); \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + MultinomialOp<CPUDevice, TYPE, int64>); TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); @@ -220,12 +227,20 @@ TF_CALL_double(REGISTER); #undef REGISTER #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("Multinomial") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_samples") \ - .TypeConstraint<TYPE>("T"), \ - MultinomialOp<GPUDevice, TYPE>) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + MultinomialOp<GPUDevice, TYPE, int32>) \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + MultinomialOp<GPUDevice, TYPE, int64>) + TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/core/kernels/multinomial_op.h b/tensorflow/core/kernels/multinomial_op.h index af5e81f219..6e41060aa4 100644 --- a/tensorflow/core/kernels/multinomial_op.h +++ b/tensorflow/core/kernels/multinomial_op.h @@ -21,7 +21,7 @@ namespace tensorflow { namespace functor { // Generic helper functor for the Multinomial Op. -template <typename Device, typename T> +template <typename Device, typename T, typename OutputType> struct MultinomialFunctor; } // namespace functor diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc index 19b4f3ca55..5cc5877cce 100644 --- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc +++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc @@ -37,20 +37,22 @@ using GPUDevice = Eigen::GpuDevice; // Kernel for Multinomial op. Data is interpreted to have the following shapes: // scores: [B, S, C]; maxima: [B, S]; output: [B, S]. +template <typename OutputType> __global__ void MultinomialKernel(int32 nthreads, const int32 num_classes, const int32 num_samples, const float* scores, - const float* maxima, int64* output) { + const float* maxima, OutputType* output) { CUDA_1D_KERNEL_LOOP(index, nthreads) { const int maxima_idx = index / num_classes; if (ldg(maxima + maxima_idx) == ldg(scores + index)) { - CudaAtomicMax(reinterpret_cast<uint64*>(output + maxima_idx), - static_cast<uint64>(index % num_classes)); + using UnsignedOutputType = typename std::make_unsigned<OutputType>::type; + CudaAtomicMax(reinterpret_cast<UnsignedOutputType*>(output + maxima_idx), + static_cast<UnsignedOutputType>(index % num_classes)); } } } -template <typename T> -struct MultinomialFunctor<GPUDevice, T> { +template <typename T, typename OutputType> +struct MultinomialFunctor<GPUDevice, T, OutputType> { void operator()(OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<float>::Flat noises, @@ -58,7 +60,7 @@ struct MultinomialFunctor<GPUDevice, T> { typename TTypes<float>::Flat maxima, int batch_size, int num_classes, int num_samples, const random::PhiloxRandom& gen, - typename TTypes<int64>::Matrix output) { + typename TTypes<OutputType>::Matrix output) { // Uniform, [0, 1). typedef random::UniformDistribution<random::PhiloxRandom, float> Dist; functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(), @@ -111,11 +113,17 @@ struct MultinomialFunctor<GPUDevice, T> { }; // Explicit instantiation of the GPU functors. -template struct MultinomialFunctor<GPUDevice, Eigen::half>; -template struct MultinomialFunctor<GPUDevice, float>; -template struct MultinomialFunctor<GPUDevice, double>; -template struct MultinomialFunctor<GPUDevice, int32>; -template struct MultinomialFunctor<GPUDevice, int64>; +template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>; +template struct MultinomialFunctor<GPUDevice, float, int32>; +template struct MultinomialFunctor<GPUDevice, double, int32>; +template struct MultinomialFunctor<GPUDevice, int32, int32>; +template struct MultinomialFunctor<GPUDevice, int64, int32>; + +template struct MultinomialFunctor<GPUDevice, Eigen::half, int64>; +template struct MultinomialFunctor<GPUDevice, float, int64>; +template struct MultinomialFunctor<GPUDevice, double, int64>; +template struct MultinomialFunctor<GPUDevice, int32, int64>; +template struct MultinomialFunctor<GPUDevice, int64, int64>; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc index 0db7c63b8b..a841291ddd 100644 --- a/tensorflow/core/kernels/nn_ops_test.cc +++ b/tensorflow/core/kernels/nn_ops_test.cc @@ -653,6 +653,8 @@ BM_ConvFloatDepthwiseFwd(32, 7, 7, 1024, 1, 1024, 3, 3, 1, SAME, conv6); // Benchmarks with different stride and padding options. BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, SAME, conv7); BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, VALID, conv8); +BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9); +BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10); #define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \ static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \ diff --git a/tensorflow/core/kernels/padded_batch_dataset_op.cc b/tensorflow/core/kernels/padded_batch_dataset_op.cc index 7c28d955e1..cef5bde156 100644 --- a/tensorflow/core/kernels/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/padded_batch_dataset_op.cc @@ -242,7 +242,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { Node* batch_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); - std::vector<NodeBuilder::NodeOut> padded_shapes; + std::vector<Node*> padded_shapes; padded_shapes.reserve(padded_shapes_.size()); for (int i = 0; i < padded_shapes_.size(); i++) { Node* node; @@ -254,7 +254,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { padded_shapes.emplace_back(node); } - std::vector<NodeBuilder::NodeOut> padding_values; + std::vector<Node*> padding_values; padding_values.reserve(padding_values_.size()); for (const Tensor& t : padding_values_) { Node* node; diff --git a/tensorflow/core/kernels/parallel_map_dataset_op.cc b/tensorflow/core/kernels/parallel_map_dataset_op.cc index 2be87f4bde..b9175fe904 100644 --- a/tensorflow/core/kernels/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/parallel_map_dataset_op.cc @@ -195,8 +195,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer* step_container = new ScopedStepContainer( - opts.step_id, [this, ctx](const string& name) { + ScopedStepContainer* step_container = + new ScopedStepContainer(opts.step_id, [this](const string& name) { dataset() ->captured_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index 3b0764bb9b..f83998e0c1 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -457,6 +457,19 @@ class QuantizedConv2DOp : public OpKernel { context, (strides_[0] == 1 && strides_[3] == 1), errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + std::vector<int32> dilations; + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations)); + OP_REQUIRES(context, dilations.size() == 4, + errors::InvalidArgument("Dilations field must " + "specify 4 dimensions")); + OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1, + errors::InvalidArgument( + "Current implementation only supports dilated rate as 1 " + "in the row and column dimensions.")); + OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } diff --git a/tensorflow/core/kernels/random_dataset_op.cc b/tensorflow/core/kernels/random_dataset_op.cc new file mode 100644 index 0000000000..03d481a593 --- /dev/null +++ b/tensorflow/core/kernels/random_dataset_op.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dataset.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class RandomDatasetOp : public DatasetOpKernel { + public: + explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + int64 seed; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed)); + + int64 seed2; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2)); + + // By TensorFlow convention, passing 0 for both seeds indicates + // that the shuffling should be seeded non-deterministically. + if (seed == 0 && seed2 == 0) { + seed = random::New64(); + seed2 = random::New64(); + } + + *output = new Dataset(ctx, seed, seed2); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, int64 seed, int64 seed2) + : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {} + + std::unique_ptr<IteratorBase> MakeIterator( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Random")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_INT64}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + string DebugString() override { + return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_, + ")::Dataset"); + } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* seed = nullptr; + Node* seed2 = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); + TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + parent_generator_(dataset()->seed_, dataset()->seed2_), + generator_(&parent_generator_) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + Tensor value_tensor(cpu_allocator(), DT_INT64, {}); + value_tensor.scalar<int64>()() = Random(); + out_tensors->emplace_back(std::move(value_tensor)); + *end_of_sequence = false; + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), + num_random_samples_)); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"), + &num_random_samples_)); + parent_generator_ = + random::PhiloxRandom(dataset()->seed_, dataset()->seed2_); + generator_ = random::SingleSampleAdapter<random::PhiloxRandom>( + &parent_generator_); + generator_.Skip(num_random_samples_); + return Status::OK(); + } + + private: + random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random() + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + num_random_samples_++; + auto out = generator_(); + return out; + } + mutex mu_; + random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); + random::SingleSampleAdapter<random::PhiloxRandom> generator_ + GUARDED_BY(mu_); + int64 num_random_samples_ GUARDED_BY(mu_) = 0; + }; + + const int64 seed_; + const int64 seed2_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU), + RandomDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc index 807ac0a456..5c537c5b9c 100644 --- a/tensorflow/core/kernels/reduction_ops_min.cc +++ b/tensorflow/core/kernels/reduction_ops_min.cc @@ -50,6 +50,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); .TypeConstraint<int64>("Tidx") \ .HostMemory("reduction_indices"), \ ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>); +REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc index 9bbe993a2f..fe8ea59f1b 100644 --- a/tensorflow/core/kernels/reduction_ops_test.cc +++ b/tensorflow/core/kernels/reduction_ops_test.cc @@ -174,6 +174,11 @@ static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) { } BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); +static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) { + ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y); +} +BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192); + static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) { ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y); } diff --git a/tensorflow/core/kernels/scan_dataset_op.cc b/tensorflow/core/kernels/scan_dataset_op.cc index 76c219f1ae..bc52322022 100644 --- a/tensorflow/core/kernels/scan_dataset_op.cc +++ b/tensorflow/core/kernels/scan_dataset_op.cc @@ -132,7 +132,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = CapturedFunction::generate_step_id(); ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { + opts.step_id, [this](const string& name) { dataset() ->captured_func_->resource_manager() ->Cleanup(name) diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 484932ab01..98c0181afb 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -21,6 +21,7 @@ limitations under the License. #endif // GOOGLE_CUDA #include "tensorflow/core/kernels/scatter_nd_op.h" + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -28,6 +29,8 @@ limitations under the License. #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -83,7 +86,10 @@ class ScatterNdUpdateOp : public OpKernel { const DataType dt = DataTypeToEnum<T>::v(); const DataType dt_ref = DataTypeToEnum<T>::ref(); const DataType index_t = DataTypeToEnum<Index>::v(); - if (IsRefType(c->input_type(0))) { + dtype_ = c->input_type(0); + if (c->input_type(0) == DT_RESOURCE) { + // TODO(apassos): what to validate here? + } else if (IsRefType(c->input_type(0))) { OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref})); OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); } else { @@ -93,7 +99,16 @@ class ScatterNdUpdateOp : public OpKernel { } void Compute(OpKernelContext* c) override { - if (use_exclusive_lock_) { + if (dtype_ == DT_RESOURCE) { + if (use_exclusive_lock_) { + Var* v; + OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); + mutex_lock m(*v->mu()); + DoCompute(c); + } else { + DoCompute(c); + } + } else if (use_exclusive_lock_) { // If we're here, it means the input type is a ref. DCHECK(IsRefType(c->input_dtype(0))); // Hold mutex while we apply updates @@ -105,6 +120,7 @@ class ScatterNdUpdateOp : public OpKernel { } private: + DataType dtype_; bool use_exclusive_lock_; void DoCompute(OpKernelContext* c) { @@ -113,7 +129,20 @@ class ScatterNdUpdateOp : public OpKernel { Tensor params; TensorShape params_shape; - if (IsRefType(c->input_dtype(0))) { + if (dtype_ == DT_RESOURCE) { + Var* v; + OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); + Tensor* t = v->tensor(); + if (!use_exclusive_lock_) { + // We're not holding the lock in the outer scope so need it here. + mutex_lock m(*v->mu()); + OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t)); + } else { + OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t)); + } + params = *t; + params_shape = params.shape(); + } else if (IsRefType(c->input_dtype(0))) { params = c->mutable_input(0, use_exclusive_lock_); params_shape = params.shape(); c->forward_ref_input_to_ref_output(0, 0); @@ -159,6 +188,16 @@ class ScatterNdUpdateOp : public OpKernel { .TypeConstraint<index_type>("Tindices"), \ ScatterNdUpdateOp<dev##Device, type, index_type, op>) +#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \ + dev, name, op) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_##dev) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices") \ + .HostMemory("ref"), \ + ScatterNdUpdateOp<dev##Device, type, index_type, op>) + #define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \ REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \ REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name) @@ -167,6 +206,11 @@ class ScatterNdUpdateOp : public OpKernel { REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op) +#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \ + REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \ + op); \ + REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op) + #define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ scatter_nd_op::UpdateOp::ADD); \ @@ -178,9 +222,11 @@ class ScatterNdUpdateOp : public OpKernel { #define REGISTER_SCATTER_ND(type, dev) \ REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd"); -#define REGISTER_SCATTER_ND_UPDATE(type, dev) \ - REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \ - scatter_nd_op::UpdateOp::ASSIGN); +#define REGISTER_SCATTER_ND_UPDATE(type, dev) \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \ + scatter_nd_op::UpdateOp::ASSIGN); \ + REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ + type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); // Registers CPU kernels. #define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \ @@ -281,8 +327,7 @@ Status ValidateUpdateShape(const TensorShape& params_shape, } template <typename Index> -Status PrepareAndValidateInputs(OpKernelContext* c, - const TensorShape& params_shape, +Status PrepareAndValidateInputs(const TensorShape& params_shape, const Tensor& indices, const Tensor& updates, int64* slice_dim, Index* num_updates, Index* slice_size) { @@ -396,7 +441,7 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, Index num_updates; Index slice_size; TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>( - c, shape, indices, updates, &slice_dim, &num_updates, &slice_size)); + shape, indices, updates, &slice_dim, &num_updates, &slice_size)); IndexFlattener<Device, Index> index_flattener; auto indices_flat = index_flattener(c, indices); diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index cfb86904d5..f4159da229 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -409,186 +409,11 @@ class DeserializeSparseOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -template <typename T> -class DeserializeManySparseOp : public OpKernel { - public: - explicit DeserializeManySparseOp(OpKernelConstruction* context) - : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor& serialized_sparse = context->input(0); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(serialized_sparse.shape()), - errors::InvalidArgument( - "Serialized sparse should be a matrix but received shape ", - serialized_sparse.shape().DebugString())); - OP_REQUIRES( - context, serialized_sparse.shape().dim_size(1) == 3, - errors::InvalidArgument( - "Serialized sparse should have 3 columns but received shape ", - serialized_sparse.shape().DebugString())); - - int num_sparse_tensors = serialized_sparse.shape().dim_size(0); - - OP_REQUIRES( - context, num_sparse_tensors > 0, - errors::InvalidArgument("Must have at least 1 serialized SparseTensor, " - "but input matrix has 0 rows")); - - std::vector<Tensor> indices_to_concat; - std::vector<Tensor> values_to_concat; - std::vector<TensorShape> shapes_to_concat; - - const auto& serialized_sparse_t = serialized_sparse.matrix<string>(); - - for (int i = 0; i < num_sparse_tensors; ++i) { - Tensor output_indices(DT_INT64); - Tensor output_values(DataTypeToEnum<T>::value); - Tensor output_shape(DT_INT64); - TensorProto proto_indices; - TensorProto proto_values; - TensorProto proto_shape; - - OP_REQUIRES( - context, - ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)), - errors::InvalidArgument("Could not parse serialized_sparse[", i, - ", 0]")); - OP_REQUIRES(context, - ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)), - errors::InvalidArgument("Could not parse serialized_sparse[", - i, ", 1]")); - OP_REQUIRES(context, - ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)), - errors::InvalidArgument("Could not parse serialized_sparse[", - i, ", 2]")); - - OP_REQUIRES(context, output_indices.FromProto(proto_indices), - errors::InvalidArgument( - "Could not construct Tensor serialized_sparse[", i, - ", 0] (indices)")); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), - errors::InvalidArgument( - "Expected serialized_sparse[", i, - ", 0] to represent an index matrix but received shape ", - output_indices.shape().DebugString())); - OP_REQUIRES(context, output_values.FromProto(proto_values), - errors::InvalidArgument( - "Could not construct Tensor serialized_sparse[", i, - ", 1] (values)")); - OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()), - errors::InvalidArgument( - "Expected serialized_sparse[", i, - ", 1] to represent a values vector but received shape ", - output_values.shape().DebugString())); - OP_REQUIRES(context, output_shape.FromProto(proto_shape), - errors::InvalidArgument( - "Could not construct Tensor serialized_sparse[", i, - ", 2] (shape)")); - OP_REQUIRES( - context, TensorShapeUtils::IsVector(output_shape.shape()), - errors::InvalidArgument("Expected serialized_sparse[", i, - ", 1] to be a shape vector but its shape is ", - output_shape.shape().DebugString())); - - OP_REQUIRES( - context, DataTypeToEnum<T>::value == output_values.dtype(), - errors::InvalidArgument( - "Requested SparseTensor of type ", - DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i, - "].values.dtype() == ", DataTypeString(output_values.dtype()))); - - int64 num_entries = output_indices.dim_size(0); - OP_REQUIRES(context, num_entries == output_values.dim_size(0), - errors::InvalidArgument( - "Expected row counts of SparseTensor[", i, - "].indices and SparseTensor[", i, - "].values to match but they do not: ", num_entries, - " vs. ", output_values.dim_size(0))); - int rank = output_indices.dim_size(1); - OP_REQUIRES( - context, rank == output_shape.dim_size(0), - errors::InvalidArgument("Expected column counts of SparseTensor[", i, - "].indices to match size of SparseTensor[", i, - "].shape " - "but they do not: ", - rank, " vs. ", output_shape.dim_size(0))); - - // Now we expand each SparseTensors' indices and shape by - // prefixing a dimension - Tensor expanded_indices( - DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)})); - Tensor expanded_shape(DT_INT64, - TensorShape({1 + output_shape.dim_size(0)})); - const auto& output_indices_t = output_indices.matrix<int64>(); - const auto& output_shape_t = output_shape.vec<int64>(); - auto expanded_indices_t = expanded_indices.matrix<int64>(); - auto expanded_shape_t = expanded_shape.vec<int64>(); - expanded_indices_t.chip<1>(0).setZero(); - Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); - Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); - expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; - expanded_shape_t(0) = 1; - std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); - - TensorShape expanded_tensor_shape(expanded_shape.vec<int64>()); - - indices_to_concat.push_back(expanded_indices); - values_to_concat.push_back(output_values); - shapes_to_concat.push_back(expanded_tensor_shape); - } - - int rank = -1; - for (int i = 0; i < num_sparse_tensors; ++i) { - if (rank < 0) rank = shapes_to_concat[i].dims(); - OP_REQUIRES(context, rank == shapes_to_concat[i].dims(), - errors::InvalidArgument( - "Inconsistent rank across SparseTensors: rank prior to " - "SparseTensor[", - i, "] was: ", rank, " but rank of SparseTensor[", i, - "] is: ", shapes_to_concat[i].dims())); - } - - // SparseTensor::Concat requires consistent shape for all but the - // primary order dimension (dimension 0 in this case). So we get - // the maximum value across all the input SparseTensors for each - // dimension and use that. - TensorShape preconcat_shape(shapes_to_concat[0]); - for (int i = 0; i < num_sparse_tensors; ++i) { - for (int d = 0; d < rank; ++d) { - preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d), - shapes_to_concat[i].dim_size(d))); - } - } - - // Dimension 0 is the primary dimension. - gtl::InlinedVector<int64, 8> std_order(rank); - std::iota(std_order.begin(), std_order.end(), 0); - - std::vector<SparseTensor> tensors_to_concat; - tensors_to_concat.reserve(num_sparse_tensors); - for (int i = 0; i < num_sparse_tensors; ++i) { - tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i], - preconcat_shape, std_order); - } - - SparseTensor output = SparseTensor::Concat<T>(tensors_to_concat); - - Tensor final_output_shape(DT_INT64, TensorShape({output.dims()})); - - std::copy_n(output.shape().data(), output.dims(), - final_output_shape.vec<int64>().data()); - - context->set_output(0, output.indices()); - context->set_output(1, output.values()); - context->set_output(2, final_output_shape); - } -}; - #define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("dtype"), \ - DeserializeManySparseOp<type>) + DeserializeSparseOp<type>) TF_CALL_ALL_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h index 1f38bdce8c..d3a267ed87 100644 --- a/tensorflow/core/kernels/softmax_op_functor.h +++ b/tensorflow/core/kernels/softmax_op_functor.h @@ -64,23 +64,21 @@ struct SoftmaxEigenImpl { one_by_class.set(1, num_classes); #endif // shifted_logits = logits - max(logits along classes); - auto shifted_logits = (logits - - logits.maximum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); + auto shifted_logits = (logits - logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); if (log) { // Calculate the log of the softmax // softmax = logits - max(logits along classes); softmax.device(d) = shifted_logits; // softmax = softmax - log(sum(exp(softmax along classes))); - softmax.device(d) = (softmax - - softmax.exp() - .sum(along_class) - .eval() - .reshape(batch_by_one) - .log() - .broadcast(one_by_class)); + softmax.device(d) = (softmax - softmax.exp() + .sum(along_class) + .log() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); } else { // NOTE(touts): If you modify this implementation please run // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc. @@ -88,12 +86,11 @@ struct SoftmaxEigenImpl { // softmax = exp(logits - max(logits along classes)); softmax.device(d) = shifted_logits.exp(); // softmax = softmax * (1 / sum(softmax along classes)); - softmax.device(d) = (softmax * - softmax.sum(along_class) - .inverse() - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); + softmax.device(d) = (softmax * softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); } } }; diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 8fc40db3cc..73b6d4cf6a 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -427,6 +427,7 @@ REGISTER_STRIDED_SLICE(bfloat16); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel diff --git a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc index a8487f49f4..8ca27e3b92 100644 --- a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc +++ b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc @@ -53,6 +53,7 @@ typedef Eigen::GpuDevice GPUDevice; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); TF_CALL_complex64(DEFINE_GPU_KERNELS); TF_CALL_complex128(DEFINE_GPU_KERNELS); +TF_CALL_int64(DEFINE_GPU_KERNELS); DEFINE_GPU_KERNELS(int32); #undef DEFINE_GPU_KERNELS diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc index fe53434d17..5cf9931188 100644 --- a/tensorflow/core/kernels/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_dataset_op.cc @@ -70,7 +70,7 @@ class TensorDatasetOp : public DatasetOpKernel { protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** output) const override { - std::vector<NodeBuilder::NodeOut> components; + std::vector<Node*> components; components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc index e85f59b584..19d4816ff3 100644 --- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc @@ -86,7 +86,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel { protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** output) const override { - std::vector<NodeBuilder::NodeOut> components; + std::vector<Node*> components; components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc index 9381915ae9..31e5737f62 100644 --- a/tensorflow/core/kernels/zip_dataset_op.cc +++ b/tensorflow/core/kernels/zip_dataset_op.cc @@ -80,7 +80,7 @@ class ZipDatasetOp : public DatasetOpKernel { protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - std::vector<NodeBuilder::NodeOut> input_graph_nodes; + std::vector<Node*> input_graph_nodes; input_graph_nodes.reserve(inputs_.size()); for (const auto& input : inputs_) { Node* input_node; diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc index 2a04f7bd39..55e481d0e6 100644 --- a/tensorflow/core/lib/core/arena.cc +++ b/tensorflow/core/lib/core/arena.cc @@ -28,6 +28,7 @@ limitations under the License. #include <algorithm> #include <vector> +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mem.h" @@ -113,24 +114,11 @@ void Arena::MakeNewBlock(const uint32 alignment) { CHECK(SatisfyAlignment(alignment)); } -// The following simple numeric routines also exist in util/math/mathutil.h -// but we don't want to depend on that library. - -// Euclid's algorithm for Greatest Common Denominator. -static uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} - static uint32 LeastCommonMultiple(uint32 a, uint32 b) { if (a > b) { - return (a / GCD(a, b)) * b; + return (a / MathUtil::GCD<uint32>(a, b)) * b; } else if (a < b) { - return (b / GCD(b, a)) * a; + return (b / MathUtil::GCD<uint32>(b, a)) * a; } else { return a; } diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h index 6f279865e7..9e71598622 100644 --- a/tensorflow/core/lib/math/math_util.h +++ b/tensorflow/core/lib/math/math_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LIB_MATH_MATH_UTIL_H_ #define TENSORFLOW_LIB_MATH_MATH_UTIL_H_ +#include <type_traits> + #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -59,6 +61,9 @@ class MathUtil { template <typename IntegralType, bool ceil> static IntegralType CeilOrFloorOfRatio(IntegralType numerator, IntegralType denominator); + + template <typename IntegralType> + static IntegralType GCD(IntegralType x, IntegralType y); }; // ---- CeilOrFloorOfRatio ---- @@ -107,6 +112,18 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator, } } +template <typename IntegralType> +IntegralType MathUtil::GCD(IntegralType a, IntegralType b) { + static_assert(std::is_unsigned<IntegralType>::value, + "signed GCD not supported!"); + while (b != 0) { + IntegralType r = a % b; + a = b; + b = r; + } + return a; +} + } // namespace tensorflow #endif // TENSORFLOW_LIB_MATH_MATH_UTIL_H_ diff --git a/tensorflow/core/lib/math/math_util_test.cc b/tensorflow/core/lib/math/math_util_test.cc index eaf8c31a43..a96e5467c3 100644 --- a/tensorflow/core/lib/math/math_util_test.cc +++ b/tensorflow/core/lib/math/math_util_test.cc @@ -195,4 +195,33 @@ TEST(MathUtil, CeilOfRatio) { #endif } +struct GCDTestCase { + unsigned int x; + unsigned int y; + unsigned int gcd; +}; + +TEST(MathUtil, GCD) { + std::vector<GCDTestCase> testcases({ + {10, 20, 10}, // + {27, 8, 1}, // + {4, 3, 1}, // + {6, 8, 2}, // + {5, 0, 5}, // + {5, 5, 5}, // + {0, 0, 0} // + }); + + for (const auto& tc : testcases) { + EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.x, tc.y)); + EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.y, tc.x)); + EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.x, tc.y)); + EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.y, tc.x)); + } + + const uint64 biggish_prime = 1666666667; + EXPECT_EQ(biggish_prime, + MathUtil::GCD<uint64>(biggish_prime * 3, biggish_prime * 4)); +} + } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h index fbef25619f..acdb0d86ed 100644 --- a/tensorflow/core/lib/monitoring/collected_metrics.h +++ b/tensorflow/core/lib/monitoring/collected_metrics.h @@ -88,6 +88,7 @@ struct Point { ValueType value_type; int64 int64_value; string string_value; + bool bool_value; HistogramProto histogram_value; // start_timestamp and end_timestamp indicate the time period over which this diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 113d37e07d..2c8e250c56 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -225,6 +225,12 @@ inline void CollectValue(const string& value, Point* const point) { } template <> +inline void CollectValue(const bool& value, Point* const point) { + point->value_type = ValueType::kBool; + point->bool_value = value; +} + +template <> inline void CollectValue(const HistogramProto& value, Point* const point) { point->value_type = ValueType::kHistogram; // This is inefficient. If and when we hit snags, we can change the API to do diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h index 75471cfb22..ec978a9193 100644 --- a/tensorflow/core/lib/monitoring/gauge.h +++ b/tensorflow/core/lib/monitoring/gauge.h @@ -86,8 +86,29 @@ class GaugeCell<int64> { TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell); }; +// Explicit specialization of GaugeCell<bool>. Compared to the primary +// template, it uses atomic values as opposed to mutex. This class is +// thread-safe. +template <> +class GaugeCell<bool> { + public: + explicit GaugeCell(bool value) : value_(value) {} + ~GaugeCell() {} + + // Atomically sets the value. + void Set(bool value); + + // Retrieves the current value. + bool value() const; + + private: + std::atomic<bool> value_; + + TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell); +}; + // A stateful class for updating a gauge-like metric. Allowed ValueType are -// int64 and string. +// int64, string and bool. // // This class encapsulates a set of values (or a single value for a label-less // metric). Each value is identified by a tuple of labels. The class allows the @@ -117,6 +138,9 @@ class Gauge { // // auto* integer_gauge = Gauge<int64, 0>::New("/tensorflow/integer_gauge", // "Integer gauge") + // + // auto* bool_gauge = Gauge<bool, 0>::New("/tensorflow/bool_gauge", + // "Bool gauge") template <typename... MetricDefArgs> static Gauge* New(MetricDefArgs&&... metric_def_args); @@ -172,12 +196,17 @@ inline void GaugeCell<int64>::Set(int64 value) { value_ = value; } inline int64 GaugeCell<int64>::value() const { return value_; } +inline void GaugeCell<bool>::Set(bool value) { value_ = value; } + +inline bool GaugeCell<bool>::value() const { return value_; } + template <typename ValueType, int NumLabels> template <typename... MetricDefArgs> Gauge<ValueType, NumLabels>* Gauge<ValueType, NumLabels>::New( MetricDefArgs&&... metric_def_args) { static_assert(std::is_same<ValueType, int64>::value || - std::is_same<ValueType, string>::value, + std::is_same<ValueType, string>::value || + std::is_same<ValueType, bool>::value, "Gauge only allows int64 and string types."); return new Gauge<ValueType, NumLabels>( MetricDef<MetricKind::kGauge, ValueType, NumLabels>( diff --git a/tensorflow/core/lib/monitoring/gauge_test.cc b/tensorflow/core/lib/monitoring/gauge_test.cc index f98cfe2a3b..c8f673db38 100644 --- a/tensorflow/core/lib/monitoring/gauge_test.cc +++ b/tensorflow/core/lib/monitoring/gauge_test.cc @@ -87,6 +87,28 @@ TEST(GaugeOfStringValue, GetCell) { EXPECT_EQ("bar", same_cell->value()); } +auto* bool_gauge = + Gauge<bool, 0>::New("/tensorflow/test/bool_gauge", "Gauge of bool value."); + +TEST(GaugeOfBoolValue, InitializedWithFalseValue) { + EXPECT_EQ(false, bool_gauge->GetCell()->value()); +} + +TEST(GaugeOfBoolValue, GetCell) { + auto* cell = bool_gauge->GetCell(); + EXPECT_EQ(false, cell->value()); + + cell->Set(true); + EXPECT_EQ(true, cell->value()); + + auto* same_cell = bool_gauge->GetCell(); + EXPECT_EQ(true, cell->value()); + + same_cell->Set(false); + EXPECT_EQ(false, cell->value()); + EXPECT_EQ(false, same_cell->value()); +} + } // namespace } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index a7f14f9c94..f046842618 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -28,16 +28,16 @@ namespace monitoring { // The different metric kinds available. // // Gauge indicates that the metric's values are instantaneous measurements of a -// (typically) continuously varying quantity or a string value. Examples: a -// process's current heap size, a queue's current length, the name of the binary -// used by a process. +// (typically) continuously varying value. Examples: a process's current heap +// size, a queue's current length, the name of the binary used by a process, +// whether a task is complete. // // Cumulative indicates that the metric's values represent non-negative changes // over specified time periods. Example: the number of rpc calls to a service. enum class MetricKind : int { kGauge = 0, kCumulative }; // The type of the metric values. -enum class ValueType : int { kInt64 = 0, kHistogram, kString }; +enum class ValueType : int { kInt64 = 0, kHistogram, kString, kBool }; // Everything in the internal namespace is implementation details. Do not depend // on this. @@ -61,6 +61,11 @@ inline ValueType GetValueType<string>() { return ValueType::kString; } +template <> +inline ValueType GetValueType<bool>() { + return ValueType::kBool; +} + } // namespace internal // Abstract base class for a metric definition. diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 9fa6423d59..6f4ea09206 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -724,8 +724,8 @@ REGISTER_OP("OnesLike") .Input("x: T") .Output("y: T") .Attr( - "T: {float, double, int8, uint8, int16, uint16, int32, int64, " - "complex64, complex128, bool}") + "T: {bfloat16, float, double, int8, uint8, int16, uint16, int32, " + "int64, complex64, complex128, bool}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns a tensor of ones with the same shape and type as x. @@ -738,7 +738,7 @@ y: a tensor of the same shape and type as x but filled with ones. REGISTER_OP("Diag") .Input("diagonal: T") .Output("output: T") - .Attr("T: {float, double, int32, int64, complex64, complex128}") + .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}") .SetShapeFn([](InferenceContext* c) { ShapeHandle in = c->input(0); TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in)); @@ -776,7 +776,7 @@ diagonal: Rank k tensor where k is at most 1. REGISTER_OP("DiagPart") .Input("input: T") .Output("diagonal: T") - .Attr("T: {float, double, int32, int64, complex64, complex128}") + .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}") .SetShapeFn([](InferenceContext* c) { ShapeHandle in = c->input(0); if (!c->RankKnown(in)) { @@ -1059,9 +1059,8 @@ REGISTER_OP("Reverse") .Input("dims: bool") .Output("output: T") .Attr( - "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, " - "double, complex64, " - "complex128, string}") + "T: {uint8, int8, uint16, int16, int32, int64, bool, half, " + "float, double, complex64, complex128, string}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle dims; @@ -1137,9 +1136,8 @@ REGISTER_OP("ReverseV2") .Output("output: T") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr( - "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, " - "double, complex64, " - "complex128, string}") + "T: {uint8, int8, uint16, int16, int32, int64, bool, half, bfloat16, " + "float, double, complex64, complex128, string}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle axis; @@ -1834,7 +1832,7 @@ this operation. REGISTER_OP("CheckNumerics") .Input("tensor: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("message: string") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( @@ -4565,12 +4563,12 @@ REGISTER_OP("Bitcast") .Output("output: type") // All supported dtypes are listed here to include qint16 and quint16. .Attr( - "T: {float, double, int64, int32, uint8, uint16, int8, int16," + "T: {bfloat16, float, double, int64, int32, uint8, uint16, int8, int16," " complex64, complex128, qint8, quint8, qint16, quint16, qint32," " half}") .Attr( - "type: {float, double, int64, int32, uint8, uint16, int8, int16," - " complex64, complex128, qint8, quint8, qint16, quint16, qint32," + "type: {bfloat16, float, double, int64, int32, uint8, uint16, int8, " + "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32," " half}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); @@ -4782,7 +4780,7 @@ REGISTER_OP("QuantizeAndDequantize") .Attr("input_min: float = 0") .Attr("input_max: float = 0") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Deprecated(22, "Replaced by QuantizeAndDequantizeV2") .Doc(R"doc( @@ -4798,7 +4796,7 @@ REGISTER_OP("QuantizeAndDequantizeV2") .Attr("num_bits: int = 8") .Attr("range_given: bool = false") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); @@ -4877,7 +4875,7 @@ REGISTER_OP("QuantizeAndDequantizeV3") .Attr("signed_input: bool = true") .Attr("range_given: bool = true") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 6bf226e7a5..be41531347 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -469,6 +469,24 @@ stop: corresponds to stop in python's xrange(). step: corresponds to step in python's xrange(). )doc"); +REGISTER_OP("RandomDataset") + .Input("seed: int64") + .Input("seed2: int64") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a Dataset that returns pseudorandom numbers. + +seed: A scalar seed for the random number generator. If either seed or + seed2 is set to be non-zero, the random number generator is seeded + by the given seed. Otherwise, a random seed is used. +seed2: A second scalar seed to avoid seed collision. +)doc"); + REGISTER_OP("ShuffleDataset") .Input("input_dataset: variant") .Input("buffer_size: int64") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index df75caca37..45ebfa203b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -85,7 +85,7 @@ REGISTER_OP("BatchMatMul") .Input("x: T") .Input("y: T") .Output("output: T") - .Attr("T: {half, float, double, int32, complex64, complex128}") + .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") .SetShapeFn([](InferenceContext* c) { @@ -184,7 +184,7 @@ _HostCast requires its input and produces its output in host memory. REGISTER_OP("Abs") .Input("x: T") .Output("y: T") - .Attr("T: {half, float, double, int32, int64}") + .Attr("T: {half, bfloat16, float, double, int32, int64}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Computes the absolute value of a tensor. @@ -210,29 +210,31 @@ value is computed as \\( \sqrt{a^2 + b^2}\\). )doc"); // Declares cwise unary operations signature: 't -> 't -#define UNARY() \ - Input("x: T") \ - .Output("y: T") \ - .Attr("T: {half, float, double, int32, int64, complex64, complex128}") \ +#define UNARY() \ + Input("x: T") \ + .Output("y: T") \ + .Attr( \ + "T: {half, bfloat16, float, double, int32, int64, complex64, " \ + "complex128}") \ .SetShapeFn(shape_inference::UnchangedShape) -#define UNARY_REAL() \ - Input("x: T") \ - .Output("y: T") \ - .Attr("T: {half, float, double}") \ +#define UNARY_REAL() \ + Input("x: T") \ + .Output("y: T") \ + .Attr("T: {half, bfloat16, float, double}") \ .SetShapeFn(shape_inference::UnchangedShape) -#define UNARY_COMPLEX() \ - Input("x: T") \ - .Output("y: T") \ - .Attr("T: {half, float, double, complex64, complex128}") \ +#define UNARY_COMPLEX() \ + Input("x: T") \ + .Output("y: T") \ + .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \ .SetShapeFn(shape_inference::UnchangedShape) -#define UNARY_GRADIENT_COMPLEX() \ - Input("y: T") \ - .Input("dy: T") \ - .Output("z: T") \ - .Attr("T: {half, float, double, complex64, complex128}") \ +#define UNARY_GRADIENT_COMPLEX() \ + Input("y: T") \ + .Input("dy: T") \ + .Output("z: T") \ + .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \ .SetShapeFn(shape_inference::UnchangedShape) REGISTER_OP("Neg") @@ -481,7 +483,7 @@ Computes atan of x element-wise. REGISTER_OP("IsNan") .Input("x: T") .Output("y: bool") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns which elements of x are NaN. @@ -494,7 +496,7 @@ Equivalent to np.isnan REGISTER_OP("IsInf") .Input("x: T") .Output("y: bool") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns which elements of x are Inf. @@ -507,7 +509,7 @@ Equivalent to np.isinf REGISTER_OP("IsFinite") .Input("x: T") .Output("y: bool") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns which elements of x are finite. @@ -520,7 +522,9 @@ Equivalent to np.isfinite REGISTER_OP("Sign") .Input("x: T") .Output("y: T") - .Attr("T: {half, float, double, int32, int64, complex64, complex128}") + .Attr( + "T: {half, bfloat16, float, double, int32, int64, complex64, " + "complex128}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns an element-wise indication of the sign of a number. @@ -533,7 +537,7 @@ For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. REGISTER_OP("Floor") .Input("x: T") .Output("y: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns element-wise largest integer not greater than x. @@ -542,7 +546,7 @@ Returns element-wise largest integer not greater than x. REGISTER_OP("Ceil") .Input("x: T") .Output("y: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns element-wise smallest integer in not less than x. @@ -551,7 +555,7 @@ Returns element-wise smallest integer in not less than x. REGISTER_OP("Rint") .Input("x: T") .Output("y: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Returns element-wise integer closest to x. @@ -569,22 +573,23 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] // Declares cwise binary operations signature: 't, 't -> 't. -#define BINARY_MORE() \ - Input("x: T").Input("y: T").Output("z: T").Attr( \ - "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " \ - "complex64, complex128}") +#define BINARY_MORE() \ + Input("x: T").Input("y: T").Output("z: T").Attr( \ + "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \ + "int64, complex64, complex128}") -#define BINARY_FEWER() \ - Input("x: T").Input("y: T").Output("z: T").Attr( \ - "T: {half, float, double, int32, int64, complex64, complex128}") +#define BINARY_FEWER() \ + Input("x: T").Input("y: T").Output("z: T").Attr( \ + "T: {half, bfloat16, float, double, int32, int64, complex64, " \ + "complex128}") REGISTER_OP("Add") .Input("x: T") .Input("y: T") .Output("z: T") .Attr( - "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " - "complex128, string}") + "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, " + "complex64, complex128, string}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns x + y element-wise. @@ -600,8 +605,8 @@ REGISTER_OP("AddV2") .Input("y: T") .Output("z: T") .Attr( - "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " - "complex128}") + "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, " + "complex64, complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .SetIsAggregate() .SetIsCommutative() @@ -757,7 +762,7 @@ REGISTER_OP("Maximum") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {half, float, double, int32, int64}") + .Attr("T: {half, bfloat16, float, double, int32, int64}") .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( @@ -788,7 +793,7 @@ REGISTER_OP("Minimum") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {half, float, double, int32, int64}") + .Attr("T: {half, bfloat16, float, double, int32, int64}") .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( @@ -802,7 +807,7 @@ REGISTER_OP("Mod") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {int32, int64, float, double}") + .Attr("T: {int32, int64, bfloat16, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns element-wise remainder of division. This emulates C semantics in that @@ -817,7 +822,7 @@ REGISTER_OP("FloorMod") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {int32, int64, float, double}") + .Attr("T: {int32, int64, bfloat16, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns element-wise remainder of division. When `x < 0` xor `y < 0` is @@ -832,7 +837,7 @@ REGISTER_OP("TruncateMod") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {int32, int64, float, double}") + .Attr("T: {int32, int64, bfloat16, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns element-wise remainder of division. This emulates C semantics in that @@ -847,7 +852,9 @@ REGISTER_OP("Pow") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {half, float, double, int32, int64, complex64, complex128}") + .Attr( + "T: {half, bfloat16, float, double, int32, int64, complex64, " + "complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Computes the power of one value to another. @@ -946,7 +953,7 @@ REGISTER_OP("Atan2") .Input("y: T") .Input("x: T") .Output("z: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Computes arctangent of `y/x` element-wise, respecting signs of the arguments. @@ -1064,15 +1071,15 @@ Returns the truth value of (x >= y) element-wise. // -------------------------------------------------------------------------- -#define EQUALITY_COMPARISON() \ - Input("x: T") \ - .Input("y: T") \ - .Output("z: bool") \ - .SetIsCommutative() \ - .Attr( \ - "T: {half, float, double, uint8, int8, int16, int32, int64, " \ - "complex64, " \ - "quint8, qint8, qint32, string, bool, complex128}") \ +#define EQUALITY_COMPARISON() \ + Input("x: T") \ + .Input("y: T") \ + .Output("z: bool") \ + .SetIsCommutative() \ + .Attr( \ + "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \ + "int64, complex64, quint8, qint8, qint32, string, bool, " \ + "complex128}") \ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) REGISTER_OP("Equal") @@ -1291,7 +1298,7 @@ REGISTER_OP("MatMul") .Output("product: T") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") - .Attr("T: {half, float, double, int32, complex64, complex128}") + .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}") .SetShapeFn(shape_inference::MatMulShape) .Doc(R"doc( Multiply the matrix "a" by the matrix "b". @@ -1811,10 +1818,11 @@ output: Has same shape as data, except for dimension 0 which REGISTER_OP("UnsortedSegmentSum") .Input("data: T") .Input("segment_ids: Tindices") - .Input("num_segments: int32") + .Input("num_segments: Tnumsegments") .Output("output: T") .Attr("T: numbertype") .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn) .Doc(R"doc( Computes the sum along segments of a tensor. @@ -1849,10 +1857,11 @@ output: Has same shape as data, except for the first `segment_ids.rank` REGISTER_OP("UnsortedSegmentMax") .Input("data: T") .Input("segment_ids: Tindices") - .Input("num_segments: int32") + .Input("num_segments: Tnumsegments") .Output("output: T") .Attr("T: realnumbertype") .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn) .Doc(R"doc( Computes the Max along segments of a tensor. @@ -2103,7 +2112,7 @@ REGISTER_OP("Range") .Input("limit: Tidx") .Input("delta: Tidx") .Output("output: Tidx") - .Attr("Tidx: {float, double, int32, int64} = DT_INT32") + .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), @@ -2158,7 +2167,7 @@ REGISTER_OP("LinSpace") .Input("stop: T") .Input("num: Tidx") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 654e890b57..102de94787 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -73,7 +73,7 @@ REGISTER_OP("AvgPool") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::AvgPoolShape) .Doc(R"doc( Performs average pooling on the input. @@ -101,7 +101,7 @@ REGISTER_OP("AvgPoolGrad") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -300,7 +300,7 @@ REGISTER_OP("FusedBatchNormV2") .Output("batch_variance: U") .Output("reserve_space_1: U") .Output("reserve_space_2: U") - .Attr("T: {half, float}") + .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr("data_format: string = 'NHWC'") @@ -393,7 +393,7 @@ REGISTER_OP("FusedBatchNormGradV2") .Output("offset_backprop: U") .Output("reserve_space_3: U") .Output("reserve_space_4: U") - .Attr("T: {half, float}") + .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr("data_format: string = 'NHWC'") @@ -508,11 +508,12 @@ REGISTER_OP("Conv2D") .Input("input: T") .Input("filter: T") .Output("output: T") - .Attr("T: {half, float}") + .Attr("T: {half, bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn(shape_inference::Conv2DShape) .Doc(R"doc( Computes a 2-D convolution given 4-D `input` and `filter` tensors. @@ -546,7 +547,7 @@ filter: A 4-D tensor of shape output: A 4-D tensor. The dimension order is determined by the value of `data_format`, see below for details. strides: 1-D tensor of length 4. The stride of the sliding window for each - dimension of `input`. The dimension order is determined by the value of + dimension of `input`. The dimension order is determined by the value of `data_format`, see below for details. padding: The type of padding algorithm to use. data_format: Specify the data format of the input and output data. With the @@ -554,6 +555,11 @@ data_format: Specify the data format of the input and output data. With the [batch, height, width, channels]. Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each + filter element on that dimension. The dimension order is determined by the + value of `data_format`, see above for details. Dilations in the batch and + depth dimensions must be 1. )doc"); REGISTER_OP("Conv2DBackpropInput") @@ -561,11 +567,12 @@ REGISTER_OP("Conv2DBackpropInput") .Input("filter: T") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {half, float}") + .Attr("T: {half, bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -589,10 +596,15 @@ padding: The type of padding algorithm to use. output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient w.r.t. the input of the convolution. data_format: Specify the data format of the input and output data. With the - default format "NHWC", the data is stored in the order of: - [batch, in_height, in_width, in_channels]. - Alternatively, the format could be "NCHW", the data storage order of: - [batch, in_channels, in_height, in_width]. + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each filter + element on that dimension. The dimension order is determined by the value of + `data_format`, see above for details. Dilations in the batch and depth + dimensions must be 1. )doc"); // TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a @@ -603,11 +615,12 @@ REGISTER_OP("Conv2DBackpropFilter") .Input("filter_sizes: int32") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {half, float}") + .Attr("T: {half, bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); @@ -632,10 +645,15 @@ output: 4-D with shape `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. the `filter` input of the convolution. data_format: Specify the data format of the input and output data. With the - default format "NHWC", the data is stored in the order of: - [batch, in_height, in_width, in_channels]. - Alternatively, the format could be "NCHW", the data storage order of: - [batch, in_channels, in_height, in_width]. + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each filter + element on that dimension. The dimension order is determined by the value of + `data_format`, see above for details. Dilations in the batch and depth + dimensions must be 1. )doc"); namespace { @@ -819,10 +837,11 @@ REGISTER_OP("DepthwiseConv2dNative") .Input("input: T") .Input("filter: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape) .Doc(R"doc( Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. @@ -845,7 +864,6 @@ for k in 0..in_channels-1 Must have `strides[0] = strides[3] = 1`. For the most common case of the same horizontal and vertices strides, `strides = [1, stride, stride, 1]`. - strides: 1-D of length 4. The stride of the sliding window for each dimension of `input`. padding: The type of padding algorithm to use. @@ -854,6 +872,11 @@ data_format: Specify the data format of the input and output data. With the [batch, height, width, channels]. Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each filter + element on that dimension. The dimension order is determined by the value of + `data_format`, see above for details. Dilations in the batch and depth + dimensions must be 1. )doc"); REGISTER_OP("DepthwiseConv2dNativeBackpropInput") @@ -861,10 +884,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput") .Input("filter: T") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -892,6 +916,11 @@ data_format: Specify the data format of the input and output data. With the [batch, height, width, channels]. Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each filter + element on that dimension. The dimension order is determined by the value of + `data_format`, see above for details. Dilations in the batch and depth + dimensions must be 1. output: 4-D with shape according to `data_format`. For example, if `data_format` is 'NHWC', output shape is `[batch, in_height, in_width, in_channels]`. Gradient w.r.t. the input of the @@ -903,10 +932,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter") .Input("filter_sizes: int32") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); @@ -935,6 +965,11 @@ data_format: Specify the data format of the input and output data. With the [batch, height, width, channels]. Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width]. +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each filter + element on that dimension. The dimension order is determined by the value of + `data_format`, see above for details. Dilations in the batch and depth + dimensions must be 1. output: 4-D with shape `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. the `filter` input of the convolution. @@ -945,10 +980,11 @@ REGISTER_OP("Conv3D") .Input("input: T") .Input("filter: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") .SetShapeFn(shape_inference::Conv3DShape) .Doc(R"doc( Computes a 3-D convolution given 5-D `input` and `filter` tensors. @@ -970,6 +1006,11 @@ data_format: The data format of the input and output data. With the [batch, in_depth, in_height, in_width, in_channels]. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. +dilations: 1-D tensor of length 5. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each + filter element on that dimension. The dimension order is determined by the + value of `data_format`, see above for details. Dilations in the batch and + depth dimensions must be 1. )doc"); REGISTER_OP("Conv3DBackpropInput") @@ -1032,10 +1073,11 @@ REGISTER_OP("Conv3DBackpropInputV2") .Input("filter: T") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -1061,6 +1103,11 @@ data_format: The data format of the input and output data. With the [batch, in_depth, in_height, in_width, in_channels]. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. +dilations: 1-D tensor of length 5. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each + filter element on that dimension. The dimension order is determined by the + value of `data_format`, see above for details. Dilations in the batch and + depth dimensions must be 1. )doc"); @@ -1069,10 +1116,11 @@ REGISTER_OP("Conv3DBackpropFilterV2") .Input("filter_sizes: int32") .Input("out_backprop: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); @@ -1098,6 +1146,11 @@ data_format: The data format of the input and output data. With the [batch, in_depth, in_height, in_width, in_channels]. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. +dilations: 1-D tensor of length 5. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each + filter element on that dimension. The dimension order is determined by the + value of `data_format`, see above for details. Dilations in the batch and + depth dimensions must be 1. )doc"); @@ -1110,7 +1163,7 @@ REGISTER_OP("AvgPool3D") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn(shape_inference::Pool3DShape) .Doc(R"doc( Performs 3D average pooling on the input. @@ -1137,7 +1190,7 @@ REGISTER_OP("AvgPool3DGrad") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -1172,7 +1225,7 @@ REGISTER_OP("MaxPool3D") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float}") + .Attr("T: {bfloat16, float}") .SetShapeFn(shape_inference::Pool3DShape) .Doc(R"doc( Performs 3D max pooling on the input. @@ -1200,8 +1253,8 @@ REGISTER_OP("MaxPool3DGrad") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float} = DT_FLOAT") - .Attr("TInput: {float} = DT_FLOAT") + .Attr("T: {bfloat16, float} = DT_FLOAT") + .Attr("TInput: {bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { return UnchangedShapeWithRank(c, 5); }) @@ -1266,7 +1319,7 @@ data_format: The data format of the input and output data. With the REGISTER_OP("L2Loss") .Input("t: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( L2 Loss. @@ -1288,7 +1341,7 @@ REGISTER_OP("LRN") .Attr("bias: float = 1.0") .Attr("alpha: float = 1.0") .Attr("beta: float = 0.5") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: {half, bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { return UnchangedShapeWithRank(c, 4); }) @@ -1323,7 +1376,7 @@ REGISTER_OP("LRNGrad") .Attr("bias: float = 1.0") .Attr("alpha: float = 1.0") .Attr("beta: float = 0.5") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: {half, bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads @@ -1349,8 +1402,8 @@ output: The gradients for LRN. REGISTER_OP("MaxPool") .Attr( - "T: {float, double, int32, int64, uint8, int16, int8, uint16, " - "half, qint8} = DT_FLOAT") + "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " + "uint16, qint8} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) @@ -1376,8 +1429,8 @@ output: The max pooled output tensor. REGISTER_OP("MaxPoolV2") .Attr( - "T: {float, double, int32, int64, uint8, int16, int8, uint16, " - "half, qint8} = DT_FLOAT") + "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " + "uint16, qint8} = DT_FLOAT") .Attr(GetPaddingAttrString()) .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Input("input: T") @@ -1860,7 +1913,7 @@ backprops: The gradients: REGISTER_OP("Elu") .Input("features: T") .Output("activations: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. @@ -1873,7 +1926,7 @@ REGISTER_OP("EluGrad") .Input("gradients: T") .Input("outputs: T") .Output("backprops: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( Computes gradients for the exponential linear (Elu) operation. @@ -1887,7 +1940,7 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0, REGISTER_OP("Selu") .Input("features: T") .Output("activations: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` @@ -1900,7 +1953,7 @@ REGISTER_OP("SeluGrad") .Input("gradients: T") .Input("outputs: T") .Output("backprops: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( Computes gradients for the scaled exponential linear (Selu) operation. @@ -1962,7 +2015,7 @@ backprops: The gradients: `gradients / (1 + abs(features)) ** 2`. REGISTER_OP("Softmax") .Input("logits: T") .Output("softmax: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }) @@ -1982,7 +2035,7 @@ softmax: Same shape as `logits`. REGISTER_OP("LogSoftmax") .Input("logits: T") .Output("logsoftmax: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }) @@ -2004,7 +2057,7 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits") .Input("labels: T") .Output("loss: T") .Output("backprop: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); @@ -2033,7 +2086,7 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits") .Input("labels: Tlabels") .Output("loss: T") .Output("backprop: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .Attr("Tlabels: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { ShapeHandle features; @@ -2613,6 +2666,7 @@ REGISTER_OP("QuantizedConv2D") .Attr("out_type: quantizedtype = DT_QINT32") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2641,7 +2695,11 @@ min_filter: The float value that the lowest quantized filter value represents. max_filter: The float value that the highest quantized filter value represents. min_output: The float value that the lowest quantized output value represents. max_output: The float value that the highest quantized output value represents. - +dilations: 1-D tensor of length 4. The dilation factor for each dimension of + `input`. If set to k > 1, there will be k-1 skipped cells between each + filter element on that dimension. The dimension order is determined by the + value of `data_format`, see above for details. Dilations in the batch and + depth dimensions must be 1. )doc"); REGISTER_OP("QuantizedMaxPool") diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index 2429171fa9..31d9c82e53 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -29,7 +29,7 @@ REGISTER_OP("RandomUniform") .Output("output: dtype") .Attr("seed: int = 0") .Attr("seed2: int = 0") - .Attr("dtype: {half,float,double}") + .Attr("dtype: {half,bfloat16,float,double}") .Attr("T: {int32, int64}") .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( @@ -87,7 +87,7 @@ REGISTER_OP("RandomStandardNormal") .Output("output: dtype") .Attr("seed: int = 0") .Attr("seed2: int = 0") - .Attr("dtype: {half,float,double}") + .Attr("dtype: {half,bfloat16,float,double}") .Attr("T: {int32, int64}") .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( @@ -115,7 +115,7 @@ REGISTER_OP("ParameterizedTruncatedNormal") .Output("output: dtype") .Attr("seed: int = 0") .Attr("seed2: int = 0") - .Attr("dtype: {half,float,double}") + .Attr("dtype: {half,bfloat16,float,double}") .Attr("T: {int32, int64}") .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( @@ -145,7 +145,7 @@ REGISTER_OP("TruncatedNormal") .Output("output: dtype") .Attr("seed: int = 0") .Attr("seed2: int = 0") - .Attr("dtype: {half,float,double}") + .Attr("dtype: {half,bfloat16,float,double}") .Attr("T: {int32, int64}") .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( @@ -201,10 +201,11 @@ REGISTER_OP("Multinomial") .SetIsStateful() .Input("logits: T") .Input("num_samples: int32") - .Output("output: int64") + .Output("output: output_dtype") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("T: realnumbertype") + .Attr("output_dtype: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { ShapeHandle logits_shape; ShapeHandle unused; diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index cdfbec85cf..bf9e673e8e 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -204,7 +204,10 @@ Status VariableShapeShapeFn(InferenceContext* c) { if (handle_data == nullptr || handle_data->empty()) { return errors::InvalidArgument("Handle doesn't have shape information."); } - c->set_output(0, (*handle_data)[0].shape); + ShapeHandle var_shape = (*handle_data)[0].shape; + int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape) + : InferenceContext::kUnknownDim; + c->set_output(0, c->Vector(rank)); return Status::OK(); } diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 8414519f0b..772e2531dc 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -256,6 +256,48 @@ REGISTER_OP("DeserializeSparse") .Doc(R"doc( Deserialize `SparseTensor` objects. +The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +the last dimension stores serialized `SparseTensor` objects and the other N +dimensions (N >= 0) correspond to a batch. The ranks of the original +`SparseTensor` objects must all match. When the final `SparseTensor` is +created, its rank is the rank of the incoming `SparseTensor` objects plus N; +the sparse tensors have been concatenated along new dimensions, one for each +batch. + +The output `SparseTensor` object's shape values for the original dimensions +are the max across the input `SparseTensor` objects' shape values for the +corresponding dimensions. The new dimensions match the size of the batch. + +The input `SparseTensor` objects' indices are assumed ordered in +standard lexicographic order. If this is not the case, after this +step run `SparseReorder` to restore index ordering. + +For example, if the serialized input is a `[2 x 3]` matrix representing two +original `SparseTensor` objects: + + index = [ 0] + [10] + [20] + values = [1, 2, 3] + shape = [50] + +and + + index = [ 2] + [10] + values = [4, 5] + shape = [30] + +then the final deserialized `SparseTensor` will be: + + index = [0 0] + [0 10] + [0 20] + [1 2] + [1 10] + values = [1, 2, 3, 4, 5] + shape = [2 50] + serialized_sparse: The serialized `SparseTensor` objects. The last dimension must have 3 columns. dtype: The `dtype` of the serialized `SparseTensor` objects. diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index da5f091e9f..5b1f5d2477 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -513,6 +513,62 @@ output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done. )doc"); +REGISTER_OP("ResourceScatterNdUpdate") + .Input("ref: resource") + .Input("indices: Tindices") + .Input("updates: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = true") + .SetShapeFn(shape_inference::ScatterNdUpdateShape) + .Doc(R"doc( +Applies sparse `updates` to individual values or slices within a given +variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to update 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that update would look like this: + +```python + ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) +``` + +The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + +See @{tf.scatter_nd} for more details about how to make updates to +slices. + +ref: A resource handle. Must be from a VarHandleOp. +indices: A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated + values to add to ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. +)doc"); + REGISTER_OP("ScatterNdAdd") .Input("ref: Ref(T)") .Input("indices: Tindices") diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc index 6c0f081852..d476a1a4db 100644 --- a/tensorflow/core/platform/cloud/curl_http_request_test.cc +++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc @@ -263,7 +263,6 @@ TEST(CurlHttpRequestTest, GetRequest) { std::vector<char> scratch; scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end()); - StringPiece result; scratch.reserve(100); TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); @@ -594,7 +593,6 @@ TEST(CurlHttpRequestTest, ErrorReturnsNoResponse) { std::vector<char> scratch; scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end()); - StringPiece result; scratch.reserve(100); TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc index a472ae52fc..e1afc7b308 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.cc +++ b/tensorflow/core/platform/cloud/file_block_cache.cc @@ -181,7 +181,9 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, // The requested offset is at or beyond the end of the file. This can // happen if `offset` is not block-aligned, and the read returns the last // block in the file, which does not extend all the way out to `offset`. - return errors::OutOfRange("EOF at offset ", offset); + return errors::OutOfRange("EOF at offset ", offset, " in file ", filename, + " at position ", pos, "with data size ", + data.size()); } auto begin = data.begin(); if (offset > pos) { diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 54d38fe962..45e9b05092 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -697,6 +697,9 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", bucket, "/", object); + VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " + << offset << " of size: " << out->size(); + if (out->size() < block_size()) { // Check stat cache to see if we encountered an interrupted read. FileStatistics stat; @@ -706,6 +709,8 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, "File contents are inconsistent for file: %s @ %lu.", filename.c_str(), offset)); } + VLOG(2) << "Successful integrity check for: gs://" << bucket << "/" + << object << " @ " << offset; } } @@ -868,6 +873,11 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec))); + VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " + << " length: " << stat->length + << "; mtime_nsec: " << stat->mtime_nsec + << "; updated: " << updated; + stat->is_directory = false; return Status::OK(); }; diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md index 4c73e372e3..dd12f76d6f 100644 --- a/tensorflow/core/profiler/g3doc/options.md +++ b/tensorflow/core/profiler/g3doc/options.md @@ -60,11 +60,14 @@ Currently, profiler only tracks the allocation of memory. As a result, the accumulated memory request is uaually larger than the peak memory of the overall model. -bytes: The memory allocations requested by the operation. -peak_bytes: The peak requested memory (not de-allocated) by the operation. -residual_bytes: The memory requested by the operation and not de-allocated +It's recommended to generate timeline to see the allocator memory usage over +time. + +`bytes`: The memory allocations requested by the operation. +`peak_bytes`: The peak requested memory (not de-allocated) by the operation. +`residual_bytes`: The memory requested by the operation and not de-allocated when Compute finishes. -output_bytes: The memory output by the operation. It's not necessarily requested +`output_bytes`: The memory output by the operation. It's not necessarily requested by the current operation. For example, it can be a tensor forwarded from input to output, with in-place mutation. diff --git a/tensorflow/core/profiler/internal/tfprof_node.cc b/tensorflow/core/profiler/internal/tfprof_node.cc index 671b65d708..5cd1050bcc 100644 --- a/tensorflow/core/profiler/internal/tfprof_node.cc +++ b/tensorflow/core/profiler/internal/tfprof_node.cc @@ -139,6 +139,25 @@ void ExecStep::AddMemoryStats(const string& dev, exec_.accelerator_persistent_bytes() + step_stat.memory_stats().device_persistent_memory_size()); } + + // TODO(xpan): Make this more accurate: + // High level: Memory tracking is suspicous and requires large scale + // clean up. + // Investigte the memory usage difference between CPU/GPU with OpViewTest. + // + // 1. OpKernelConstruction::allocate_xxx is not traced. Below, we only + // discuss OpKernelContext-related allocations. + // 2. allocate_output calls allocate_tensor, which is properly tracked in + // 'NodeExecStats.memory'. + // 3. allocate_temp is only tracked through record_xxx_temp. It appears + // in 'NodeExecStats.memory_stats'. + // 4. allocate_persistent calls allocate_tensor, which is properly tracked + // in 'NodeExecStats.memory'. However, there is no way to count it as + // persistent now. + // 5. record_xxx_persistent is called when allocate_persistent + // is not used and hence tracks some complementary bytes. It appears in + // 'NodeExecStats.memory_stats'. It's suspicious. But we should + // use it now since it covers constant op. int64 residual_bytes = 0; int64 requested_bytes = 0; int64 peak_bytes = 0; @@ -147,6 +166,15 @@ void ExecStep::AddMemoryStats(const string& dev, requested_bytes += mem.total_bytes(); peak_bytes += mem.peak_bytes(); } + residual_bytes += + exec_.host_persistent_bytes() + exec_.accelerator_persistent_bytes(); + requested_bytes += exec_.host_persistent_bytes() + + exec_.accelerator_persistent_bytes() + + exec_.host_temp_bytes() + exec_.accelerator_temp_bytes(); + peak_bytes += exec_.host_persistent_bytes() + + exec_.accelerator_persistent_bytes() + exec_.host_temp_bytes() + + exec_.accelerator_temp_bytes(); + exec_.set_requested_bytes(requested_bytes); exec_.set_residual_bytes(residual_bytes); exec_.set_peak_bytes(peak_bytes); diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h index e2d0563a07..77c14cb792 100644 --- a/tensorflow/core/profiler/internal/tfprof_node.h +++ b/tensorflow/core/profiler/internal/tfprof_node.h @@ -593,17 +593,11 @@ class TFGraphNode { int64 accelerator_persistent_bytes() const { int64 persistent_bytes = 0; for (const auto& exec : execs_) { - persistent_bytes += exec.second.accelerator_persistent_bytes(); + persistent_bytes = std::max(persistent_bytes, + exec.second.accelerator_persistent_bytes()); } return persistent_bytes; } - int64 host_persistent_bytes(int64 step) const { - auto exec = execs_.find(step); - if (exec == execs_.end()) { - return 0; - } - return exec->second.host_persistent_bytes(); - } const std::map<int32, std::pair<int64, uint64>>& output_memory( int64 step) const { auto exec = execs_.find(step); diff --git a/tensorflow/core/profiler/internal/tfprof_show_test.cc b/tensorflow/core/profiler/internal/tfprof_show_test.cc index 1f19f8c322..98773ae19e 100644 --- a/tensorflow/core/profiler/internal/tfprof_show_test.cc +++ b/tensorflow/core/profiler/internal/tfprof_show_test.cc @@ -105,12 +105,13 @@ TEST_F(TFProfShowTest, DumpScopeMode) { "node name | # parameters | # float_ops | requested bytes | peak bytes | " "residual bytes | output bytes | total execution time | accelerator " "execution time | cpu execution time\n_TFProfRoot (--/451 params, --/0 " - "flops, --/0B, --/0B, --/0B, --/2.56KB, --/13us, --/0us, --/13us)\n DW " - "(3x3x3x6, 162/162 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, " - "1.28KB/1.28KB, 2us/2us, 0us/0us, 2us/2us)\n DW2 (2x2x6x12, 288/288 " - "params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, 1.28KB/1.28KB, 11us/11us, " - "0us/0us, 11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, " - "0B/0B, 0B/0B, 0us/0us, 0us/0us, 0us/0us)\n", + "flops, --/2.56KB, --/2.56KB, --/2.56KB, --/2.56KB, --/13us, --/0us, " + "--/13us)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 1.28KB/1.28KB, " + "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 2us/2us, 0us/0us, " + "2us/2us)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.28KB/1.28KB, " + "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 11us/11us, 0us/0us, " + "11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, " + "0B/0B, 0us/0us, 0us/0us, 0us/0us)\n", dump_str); EXPECT_EQ(dump_str, TestToFromProto("scope", opts)); @@ -178,22 +179,22 @@ TEST_F(TFProfShowTest, DumpOpMode) { EXPECT_EQ( "nodename|requestedbytes|totalexecutiontime|acceleratorexecutiontime|" "cpuexecutiontime|#parameters|#float_ops|opoccurrence(run|defined)|" - "inputshapes\nVariableV20B(0.00%,0.00%),13us(100.00%,0.26%),0us(100.00%," - "0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_ops(100.00%" - ",0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:13us\n\nAdd0B(" - "0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99.71%,0.00%)," - "0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_type:0:1," - "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12,\t1:1\t(" - "run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t(run*0|" - "defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0.00%)," - "0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_ops(" - "100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_" + "inputshapes\nVariableV22.56KB(100.00%,8.40%),13us(100.00%,0.26%),0us(" + "100.00%,0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_" + "ops(100.00%,0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:" + "13us\n\nAdd0B(0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99." + "71%,0.00%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_" + "type:0:1,\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12," + "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t(" + "run*0|defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0." + "00%),0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_" + "ops(100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_" "time:0us\ninput_type:0:2x2x6x12,\t1:2x2x6x12\t(run*0|defined*1)\texec_" "time:0us\ninput_type:0:3x3x3x6,\t1:3x3x3x6\t(run*0|defined*1)\texec_" "time:0us\n\nConst0B(0.00%,0.00%),2us(99.74%,0.04%),0us(100.00%,0.00%)," "2us(99.71%,0.04%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),1|" - "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D14.59KB(" - "100.00%,100.00%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99." + "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D27.90KB(" + "91.60%,91.60%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99." "67%,98.77%),0params(0.00%,0.00%),10.44kfloat_ops(100.00%,100.00%),2|" "2\n\ninput_type:0:2x3x3x6,\t1:2x2x6x12\t(run*1|defined*1)\texec_time:" "597us\ninput_type:0:2x6x6x3,\t1:3x3x3x6\t(run*1|defined*1)\texec_time:4." diff --git a/tensorflow/core/profiler/internal/tfprof_stats_test.cc b/tensorflow/core/profiler/internal/tfprof_stats_test.cc index 2f2101d76b..b86a83cb1b 100644 --- a/tensorflow/core/profiler/internal/tfprof_stats_test.cc +++ b/tensorflow/core/profiler/internal/tfprof_stats_test.cc @@ -89,21 +89,27 @@ TEST_F(TFProfStatsTest, CustomOpType) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: " - "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n " - "total_exec_micros: 2\n total_parameters: 162\n devices: " + "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: " + "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: " + "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n " + " total_requested_bytes: 1280\n total_parameters: 162\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n " "total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: " - "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: " - "288\n total_exec_micros: 11\n total_parameters: 288\n devices: " + "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n " + " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: " + "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n " + "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n " + "total_exec_micros: 11\n total_requested_bytes: 1280\n " + "total_parameters: 288\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n " "total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: " - "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n " - "total_parameters: 1\n total_definition_count: " + "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n " + " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: " + "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n " + "parameters: 1\n total_parameters: 1\n total_definition_count: " "1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: " - "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n", + "2\ntotal_definition_count: 3\ntotal_peak_bytes: " + "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); @@ -119,21 +125,27 @@ TEST_F(TFProfStatsTest, CheckPointOpType) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: " - "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n " - "total_exec_micros: 2\n total_parameters: 162\n devices: " + "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: " + "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: " + "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n " + " total_requested_bytes: 1280\n total_parameters: 162\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n " "total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: " - "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: " - "288\n total_exec_micros: 11\n total_parameters: 288\n devices: " + "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n " + " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: " + "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n " + "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n " + "total_exec_micros: 11\n total_requested_bytes: 1280\n " + "total_parameters: 288\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n " "total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: " - "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n " - "total_parameters: 1\n total_definition_count: " + "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n " + " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: " + "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n " + "parameters: 1\n total_parameters: 1\n total_definition_count: " "1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: " - "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n", + "2\ntotal_definition_count: 3\ntotal_peak_bytes: " + "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); @@ -150,7 +162,7 @@ TEST_F(TFProfStatsTest, TestGraph) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: " - "14592\ntotal_parameters: 451\nchildren {\n name: " + "30464\ntotal_parameters: 451\nchildren {\n name: " "\"DW/Initializer/random_normal/mul\"\n children {\n name: " "\"DW/Initializer/random_normal/RandomStandardNormal\"\n children {\n " " name: \"DW/Initializer/random_normal/shape\"\n " @@ -166,7 +178,7 @@ TEST_F(TFProfStatsTest, TestGraph) { "4\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: " "404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: " "6\ntotal_definition_count: 32\ntotal_peak_bytes: " - "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n", + "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); @@ -181,9 +193,9 @@ TEST_F(TFProfStatsTest, TestFloatOps) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: " - "14592\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n " - "exec_micros: 4292\n requested_bytes: 9472\n total_exec_micros: 4292\n " - " total_requested_bytes: 9472\n devices: " + "30464\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n " + "exec_micros: 4292\n requested_bytes: 18176\n total_exec_micros: " + "4292\n total_requested_bytes: 18176\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 5832\n " "total_float_ops: 5832\n input_shapes {\n key: 0\n value {\n " "dim {\n size: 2\n }\n dim {\n size: 6\n " @@ -194,11 +206,11 @@ TEST_F(TFProfStatsTest, TestFloatOps) { "6\n }\n }\n }\n accelerator_exec_micros: 226\n " "cpu_exec_micros: 4066\n total_accelerator_exec_micros: 226\n " "total_cpu_exec_micros: 4066\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n peak_bytes: 5888\n residual_bytes: 768\n " - "output_bytes: 768\n total_peak_bytes: 5888\n total_residual_bytes: " + "total_definition_count: 1\n peak_bytes: 14592\n residual_bytes: 768\n " + " output_bytes: 768\n total_peak_bytes: 14592\n total_residual_bytes: " "768\n total_output_bytes: 768\n}\nchildren {\n name: \"Conv2D_1\"\n " - "exec_micros: 597\n requested_bytes: 5120\n total_exec_micros: 597\n " - "total_requested_bytes: 5120\n devices: " + "exec_micros: 597\n requested_bytes: 9728\n total_exec_micros: 597\n " + "total_requested_bytes: 9728\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n " "total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n " "dim {\n size: 2\n }\n dim {\n size: 3\n " @@ -209,12 +221,12 @@ TEST_F(TFProfStatsTest, TestFloatOps) { "12\n }\n }\n }\n accelerator_exec_micros: 178\n " "cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n " "total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n " - "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: " + "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n " + "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: " "512\n total_output_bytes: 512\n}\ntotal_float_ops: " "10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: " "4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: " - "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n", + "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); @@ -231,9 +243,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\ntotal_exec_micros: 597\ntotal_requested_bytes: " - "5120\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n " - "requested_bytes: 5120\n total_exec_micros: 597\n " - "total_requested_bytes: 5120\n devices: " + "9728\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n " + "requested_bytes: 9728\n total_exec_micros: 597\n " + "total_requested_bytes: 9728\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n " "total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n " "dim {\n size: 2\n }\n dim {\n size: 3\n " @@ -244,12 +256,12 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) { "12\n }\n }\n }\n accelerator_exec_micros: 178\n " "cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n " "total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n " - "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n " - "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: " + "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n " + "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: " "512\n total_output_bytes: 512\n}\ntotal_float_ops: " "4608\ntotal_accelerator_exec_micros: 178\ntotal_cpu_exec_micros: " "419\ntotal_run_count: 1\ntotal_definition_count: 2\ntotal_peak_bytes: " - "4096\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n", + "8704\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); @@ -265,8 +277,9 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) { GraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: " - "14592\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n " - "exec_micros: 2\n parameters: 162\n total_exec_micros: 2\n " + "30464\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n " + "exec_micros: 2\n requested_bytes: 1280\n parameters: 162\n " + "total_exec_micros: 2\n total_requested_bytes: 1280\n " "total_parameters: 162\n devices: " "\"/job:localhost/replica:0/task:0/gpu:0\"\n tensor_value {\n dtype: " "DT_FLOAT\n value_double: -0.000534315\n value_double: " @@ -351,11 +364,13 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) { "value_double: 0.000374641\n value_double: -0.00149603\n " "value_double: -0.000317367\n value_double: -0.000417829\n }\n " "cpu_exec_micros: 2\n total_cpu_exec_micros: 2\n run_count: 1\n " - "total_run_count: 1\n total_definition_count: 10\n output_bytes: " - "1280\n total_output_bytes: 1280\n}\ntotal_float_ops: " - "10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: " - "4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: " - "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n", + "total_run_count: 1\n total_definition_count: 10\n peak_bytes: 1280\n " + "residual_bytes: 1280\n output_bytes: 1280\n total_peak_bytes: 1280\n " + "total_residual_bytes: 1280\n total_output_bytes: " + "1280\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: " + "404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: " + "6\ntotal_definition_count: 35\ntotal_peak_bytes: " + "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } diff --git a/tensorflow/core/profiler/tfprof_log.proto b/tensorflow/core/profiler/tfprof_log.proto index f92301133a..b49bdf64ac 100644 --- a/tensorflow/core/profiler/tfprof_log.proto +++ b/tensorflow/core/profiler/tfprof_log.proto @@ -124,9 +124,10 @@ message ExecProfile { int64 residual_bytes = 9; // Total bytes output by the op (not necessarily requested by the op). int64 output_bytes = 10; - // Total temporary bytes allocated and released by the op. + // NOTE: Please don't depend on the following 4 fields yet. Due to + // TensorFlow internal tracing issues, the numbers can be quite wrong. + // TODO(xpan): Fix the TensorFlow internal tracing. int64 host_temp_bytes = 11; - // Total persistent bytes (e.g. variable) allocated by the op. int64 host_persistent_bytes = 12; int64 accelerator_temp_bytes = 13; int64 accelerator_persistent_bytes = 14; diff --git a/tensorflow/docs_src/api_guides/python/reading_data.md b/tensorflow/docs_src/api_guides/python/reading_data.md index b3ebaa0f0a..4594887349 100644 --- a/tensorflow/docs_src/api_guides/python/reading_data.md +++ b/tensorflow/docs_src/api_guides/python/reading_data.md @@ -1,11 +1,11 @@ # Reading data Note: The preferred way to feed data into a tensorflow program is using the -@{$datasets$Datasets API}. +@{$datasets$`tf.data` API}. There are four methods of getting data into a TensorFlow program: -* `Dataset` API: Easily construct a complex input pipeline. (preferred method) +* `tf.data` API: Easily construct a complex input pipeline. (preferred method) * Feeding: Python code provides the data when running each step. * `QueueRunner`: a queue-based input pipeline reads the data from files at the beginning of a TensorFlow graph. @@ -14,26 +14,27 @@ There are four methods of getting data into a TensorFlow program: [TOC] -## Dataset API +## `tf.data` API See the @{$datasets$programmer's guide} for an in-depth explanation of -@{tf.data.Dataset}. The `Dataset` API allows you to extract and preprocess data -from different input/file formats, and apply transformations such as batch, -shuffle, and map to the dataset. This is an improved version of the old input -methods, feeding and `QueueRunner`. +@{tf.data.Dataset}. The `tf.data` API enables you to extract and preprocess data +from different input/file formats, and apply transformations such as batching, +shuffling, and mapping functions over the dataset. This is an improved version +of the old input methods---feeding and `QueueRunner`---which are described +below for historical purposes. ## Feeding +Warning: "Feeding" is the least efficient way to feed data into a TensorFlow +program and should only be used for small experiments and debugging. + TensorFlow's feed mechanism lets you inject data into any Tensor in a -computation graph. A python computation can thus feed data directly into the +computation graph. A Python computation can thus feed data directly into the graph. Supply feed data through the `feed_dict` argument to a run() or eval() call that initiates computation. -Warning: "Feeding" is the least efficient way to feed data into a tensorflow -program and should only be used for small experiments and debugging. - ```python with tf.Session(): input = tf.placeholder(tf.float32) @@ -55,6 +56,10 @@ and is described in the @{$mechanics$MNIST tutorial}. ## `QueueRunner` +Warning: This section discusses implementing input pipelines using the +queue-based APIs which can be cleanly replaced by the @{$datasets$`tf.data` +API}. + A typical queue-based pipeline for reading records from files has the following stages: 1. The list of filenames @@ -66,9 +71,6 @@ A typical queue-based pipeline for reading records from files has the following 7. *Optional* preprocessing 8. Example queue -Warning: This section discusses implementing input pipelines using the -queue-based APIs which can be cleanly replaced by the @{$datasets$Datasets API}. - ### Filenames, shuffling, and epoch limits For the list of filenames, use either a constant string Tensor (like @@ -499,7 +501,7 @@ You can have the train and eval in the same graph in the same process, and share their trained variables or layers. See @{$variables$the shared variables tutorial}. To support the single-graph approach -@{$programmers_guide/datasets$Datasets} also supplies +@{$programmers_guide/datasets$`tf.data`} also supplies @{$programmers_guide/datasets#creating_an_iterator$advanced iterator types} that that allow the user to change the input pipeline without rebuilding the graph or session. diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/get_started/custom_estimators.md new file mode 100644 index 0000000000..e347aa6bd0 --- /dev/null +++ b/tensorflow/docs_src/get_started/custom_estimators.md @@ -0,0 +1,576 @@ + +# Creating Custom Estimators +This document introduces custom Estimators. In particular, this document +demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that +mimics the behavior of the pre-made Estimator +@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See +the @{$get_started/estimator$Pre-Made Estimators chapter} for details. + +If you are feeling impatient, feel free to compare and contrast the following +full programs: + +* Iris implemented with the [pre-made DNNClassifier Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py). +* Iris implemented with a [custom Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py). + +## Pre-made vs. custom + +As the following figure shows, pre-made Estimators are subclasses of the +@{tf.estimator.Estimator} base class, while custom Estimators are an instance +of tf.estimator.Estimator: + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" + alt="Premade estimators are sub-classes of `Estimator`. Custom Estimators are usually (direct) instances of `Estimator`" + src="../images/custom_estimators/estimator_types.png"> +</div> +<div style="text-align: center"> +Pre-made and custom Estimators are all Estimators. +</div> + +Pre-made Estimators are fully baked. Sometimes though, you need more control +over an Estimator's behavior. That's where custom Estimators come in. You can +create a custom Estimator to do just about anything. If you want hidden layers +connected in some unusual fashion, write a custom Estimator. If you want to +calculate a unique +[metric](https://developers.google.com/machine-learning/glossary/#metric) +for your model, write a custom Estimator. Basically, if you want an Estimator +optimized for your specific problem, write a custom Estimator. + +A model function (or `model_fn`) implements the ML algorithm. The +only difference between working with pre-made Estimators and custom Estimators +is: + +* With pre-made Estimators, someone already wrote the model function for you. +* With custom Estimators, you must write the model function. + +Your model function could implement a wide range of algorithms, defining all +sorts of hidden layers and metrics. Like input functions, all model functions +must accept a standard group of input parameters and return a standard group of +output values. Just as input functions can leverage the Dataset API, model +functions can leverage the Layers API and the Metrics API. + +Let's see how to solve the Iris problem with a custom Estimator. A quick +reminder--here's the organization of the Iris model that we're trying to mimic: + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs" + src="../images/custom_estimators/full_network.png"> +</div> +<div style="text-align: center"> +Our implementation of Iris contains four features, two hidden layers, +and a logits output layer. +</div> + +## Write an Input function + +In our custom Estimator implementation, we'll reuse the input function we used +in the pre-made Estimator implementation. Namely: + +```python +def train_input_fn(features, labels, batch_size): + """An input function for training""" + # Convert the inputs to a Dataset. + dataset = tf.data.Dataset.from_tensor_slices((features, labels)) + + # Shuffle, repeat, and batch the examples. + dataset = dataset.shuffle(1000).repeat().batch(batch_size) + + # Return the read end of the pipeline. + return dataset.make_one_shot_iterator().get_next() +``` + +This input function builds an input pipeline that yields batches of +`(features, labels)` pairs, where `features` is a dictionary features. + +## Create feature columns + +<!-- TODO(markdaoust): link to feature_columns when it exists--> +As detailed in @{$get_started/estimator$Premade Estimators}, you must define +your model's feature columns to specify how the model should use each feature. +Whether working with pre-made Estimators or custom Estimators, you define +feature columns in the same fashion. + +The following code creates a simple `numeric_column` for each input feature, +indicating that the value of the input feature should be used directly as an +input to the model: + +```python +# Feature columns describe how to use the input. +my_feature_columns = [] +for key in train_x.keys(): + my_feature_columns.append(tf.feature_column.numeric_column(key=key)) +``` + +## Write a model function + +The model function we'll use has the following call signature: + +```python +def my_model_fn( + features, # This is batch_features from input_fn + labels, # This is batch_labels from input_fn + mode, # An instance of tf.estimator.ModeKeys + params): # Additional configuration +``` + +The first two arguments are the batches of features and labels returned from +the input function; that is, `features` and `labels` are the handles to the +data your model will use. The `mode` argument indicates whether the caller is +requesting training, predicting, or evaluation. + +The caller may pass `params` to an Estimator's constructor. The `params` passed +to the constructor become the `params` passed to `model_fn`. + +```python + # Build 2 hidden layer DNN with 10, 10 units respectively. + classifier = tf.estimator.Estimator( + model_fn=my_model, + params={ + 'feature_columns': my_feature_columns, + # Two hidden layers of 10 nodes each. + 'hidden_units': [10, 10], + # The model must choose between 3 classes. + 'n_classes': 3, + }) +``` + +To implement a typical model function, you must do the following: + +* (Define the model)[#define_the_model]. +* Specify additional calculations for each of + the [three different modes](#modes): + * [Predict](#predict) + * [Evaluate](#evaluate) + * [Train](#train) + +## Define the model + +The basic deep neural network model must define the following three sections: + +* An [input layer](https://developers.google.com/machine-learning/glossary/#input_layer) +* One or more [hidden layers](https://developers.google.com/machine-learning/glossary/#hidden_layer) +* An [output layer](https://developers.google.com/machine-learning/glossary/#output_layer) + +### Define the input layer + +Call @{tf.feature_column.input_layer} to convert your feature dictionary and +feature columns into input for your model. For example: + +```python + # Use `input_layer` to apply the feature columns. + net = tf.feature_column.input_layer(features, params['feature_columns']) +``` + +The preceding line applies the transformations defined by your feature columns, +creating the input layer of our model. + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="A diagram of the input layer, in this case a 1:1 mapping from raw-inputs to features." + src="../images/custom_estimators/input_layer.png"> +</div> + + +### Hidden Layers + +If you are creating a deep neural network, you must define one or more hidden +layers. The Layers API provides a rich set of functions to define all types of +hidden layers, including convolutional, pooling, and dropout layers. For Iris, +we're simply going to call @{tf.layers.dense} to create hidden layers, with +dimensions defined by `params['hidden_layers']`. In a `dense` layer each node +is connected to every node in the preceding layer. Here's the relevant code: + +``` python + # Build the hidden layers, sized according to the 'hidden_units' param. + for units in params['hidden_units']: + net = tf.layers.dense(net, units=units, activation=tf.nn.relu) +``` +* The `units` parameter defines the number of output neurons in a given layer. +* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#a) — + [Relu](https://developers.google.com/machine-learning/glossary/#ReLU) in this + case. + +The variable `net` here signifies the current top layer of the network. During +the first iteration, `net` signifies the input layer. On each loop iteration +`tf.layers.dense` creates a new layer, which takes the previous layer as its +input. So, the loop uses `net` to pass the previously created layer as input +to the layer being created. + +After creating two hidden layers, our network looks as follows. For +simplicity, the figure only shows four hidden units in each layer. + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="The input layer with two hidden layers added." + src="../images/custom_estimators/add_hidden_layer.png"> +</div> + +Note that @{tf.layers.dense} provides many additional capabilities, including +the ability to set a multitude of regularization parameters. For the sake of +simplicity, though, we're going to simply accept the default values of the +other parameters. + +### Output Layer + +We'll define the output layer by calling @{tf.layers.dense} yet again, this +time without an activation function: + +```python + # Compute logits (1 per class). + logits = tf.layers.dense(net, params['n_classes'], activation=None) +``` + +Here, `net` signifies the final hidden layer. Therefore, the full set of layers +is now connected as follows: + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="A logit output layer connected to the top hidden layer" + src="../images/custom_estimators/add_logits.png"> +</div> +<div style="text-align: center"> +The final hidden layer feeds into the output layer. +</div> + +When defining an output layer, the `units` parameter specifies the number of +outputs. So, by setting `units` to `params['n_classes']`, the model produces +one output value per class. Each element of the output vector will contains the +score, or "logit", calculated to the associated class of Iris: Setosa, +Versicolor, or Virginica, respectively. + +Later on, these logits will be transformed into probabilities by the +@{tf.nn.softmax} function. + +## Implement training, evaluation, and prediction {modes} + +The final step in creating a model function is to write branching code that +implements prediction, evaluation, and training. + +The model function gets invoked whenever someone calls the Estimator's `train`, +`evaluate`, or `predict` methods. Recall that the signature for the model +function looks like this: + +``` python +def my_model_fn( + features, # This is batch_features from input_fn + labels, # This is batch_labels from input_fn + mode): # An instance of tf.estimator.ModeKeys, see below +``` + +Focus on that third argument, mode. As the following table shows, when someone +calls train, evaluate, or predict, the Estimator framework invokes your model +function with the mode parameter set as follows: + +| Estimator method | Estimator Mode | +|:---------------------------------|:------------------| +|@{tf.estimator.Estimator.train$`train()`} |@{tf.estimator.ModeKeys.TRAIN$`ModeKeys.TRAIN`} | +|@{tf.estimator.Estimator.evaluate$`evaluate()`} |@{tf.estimator.ModeKeys.EVAL$`ModeKeys.EVAL`} | +|@{tf.estimator.Estimator.predict$`predict()`}|@{tf.estimator.ModeKeys.PREDICT$`ModeKeys.PREDICT`} | + +For example, suppose you instantiate a custom Estimator to generate an object +named `classifier`. Then, you make the following call: + +``` python +classifier = tf.estimator.Estimator(...) +classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 500)) +``` +The Estimator framework then calls your model function with mode set to +`ModeKeys.TRAIN`. + +Your model function must provide code to handle all three of the mode values. +For each mode value, your code must return an instance of +`tf.estimator.EstimatorSpec`, which contains the information the caller +requires. Let's examine each mode. + +### Predict + +When the Estimator's `predict` method is called, the `model_fn` receives +`mode = ModeKeys.PREDICT`. In this case, the model function must return a +`tf.estimator.EstimatorSpec` containing the prediction. + +The model must have been trained prior to making a prediction. The trained model +is stored on disk in the `model_dir` directory established when you +instantiated the Estimator. + +The code to generate the prediction for this model looks as follows: + +```python +# Compute predictions. +predicted_classes = tf.argmax(logits, 1) +if mode == tf.estimator.ModeKeys.PREDICT: + predictions = { + 'class_ids': predicted_classes[:, tf.newaxis], + 'probabilities': tf.nn.softmax(logits), + 'logits': logits, + } + return tf.estimator.EstimatorSpec(mode, predictions=predictions) +``` +The prediction dictionary contains everything that your model returns when run +in prediction mode. + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="Additional outputs added to the output layer." + src="../images/custom_estimators/full_network.png"> +</div> + +The `predictions` holds the following three key/value pairs: + +* `class_ids` holds the class id (0, 1, or 2) representing the model's + prediction of the most likely species for this example. +* `probabilities` holds the three probabilities (in this example, 0.02, 0.95, + and 0.03) +* `logit` holds the raw logit values (in this example, -1.3, 2.6, and -0.9) + +We return that dictionary to the caller via the `predictions` parameter of the +@{tf.estimator.EstimatorSpec}. The Estimator's +@{tf.estimator.Estimator.predict$`predict`} method will yield these +dictionaries. + +### Calculate the loss + +For both [training](#train) and [evaluation](#evaluate) we need to calculate the +model's loss. This is the +[objective](https://developers.google.com/machine-learning/glossary/#objective) +that will be optimized. + +Before we calculate loss, we we must first convert the labels from a list of +indexes `(0, 1, 2)` to a +[one-hot representation](https://developers.google.com/machine-learning/glossary/#one-hot_encoding) +by calling @{tf.one_hot}. Then, we can calculate the loss by calling +@{tf.losses.softmax_cross_entropy}. Here's the complete code: + + +```python + # Convert the labels to a one-hot tensor of shape (length of features, 3) + # and with a on-value of 1 for each one-hot vector of length 3. + onehot_labels = tf.one_hot(labels, 3, 1, 0) + + # Compute loss. + loss = tf.losses.softmax_cross_entropy( + onehot_labels=onehot_labels, logits=logits) +``` + +### Evaluate + +When the Estimator's `evaluate` method is called, the `model_fn` receives +`mode = ModeKeys.EVAL`. In this case, the model function must return a +`tf.estimator.EstimatorSpec` containing the model's loss and optionally one +or more metrics. + +Although returning metrics is optional, most custom Estimators do return at +least one metric. TensorFlow provides a Metrics module @{tf.metrics} to +calculate common metrics. For brevity's sake, we'll only return accuracy. The +@{tf.metrics.accuracy} function compares our predictions against the +true values, that is, against the labels provided by the input function. The +@{tf.metrics.accuracy} function requires the labels and predictions to have the +same shape. Here's the call to @{tf.metrics.accuracy}: + +``` python + # Compute evaluation metrics. + accuracy = tf.metrics.accuracy(labels=labels, + predictions=predicted_classes, + name='acc_op') +``` + +The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for evaluation +typically contains the following information: + +* `loss`, which is the model's loss +* `eval_metric_ops`, which is an optional dictionary of metrics. + +So, we'll create a dictionary containing our sole metric. If we had calculated +other metrics, we would have added them as additional key/value pairs to that +same dictionary. Then, we'll pass that dictionary in the `eval_metric_ops` +argument of `tf.estimator.EstimatorSpec`. Here's the code: + +```python + metrics = {'accuracy': accuracy} + tf.summary.scalar('accuracy', accuracy[1]) + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec( + mode, loss=loss, eval_metric_ops=metrics) +``` + +The @{tf.summary.scalar} will make accuracy available to TensorBoard (more on +this later). + +### Train + +When the Estimator's `train` method is called, the `model_fn` is called +with `mode = ModeKeys.TRAIN`. In this case, the model function must return an +`EstimatorSpec` that contains the loss and a training operation. + +Building the training operation will require an optimizer. We will use +@{tf.train.AdagradOptimizer} because we're mimicking the `DNNClassifier`, which +also uses `Adagrad` by default. The `tf.train` package provides many other +optimizers—feel free to experiment with them. + +Here is the code that builds the optimizer: + +``` python + # Instantiate an optimizer. + optimizer = tf.train.AdagradOptimizer(learning_rate=0.1) +``` + +Next, we train the model using the optimizer's +@{tf.train.Optimizer.minimize$`minimize`} method on the loss we calculated +earlier. + +The `minimize` method also takes a `global_step` parameter. TensorFlow uses this +parameter to count the number of training steps that have been processed +(to know when to end a training run). Furthermore, the `global_step` is +essential for TensorBoard graphs to work correctly. Simply call +@{tf.train.get_global_step} and pass the result to the `global_step` +argument of `minimize`. + +Here's the code to train the model: + +``` python + # Train the model by establishing an objective, which is to + # minimize loss using that optimizer. + train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) +``` + +The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for training +must have the following fields set: + +* `loss`, which contains the value of the loss function. +* `train_op`, which executes a training step. + +Here's our code to call `EstimatorSpec`: + +```python + # Return training information. + return tf.estimator.EstimatorSpec( + mode=tf.estimator.ModeKeys.TRAIN, + loss=loss, + train_op=train_op) +``` + +The model function is now complete. + +## The custom Estimator + +Instantiate the custom Estimator through the Estimator base class as follows: + +```python + # Build 2 hidden layer DNN with 10, 10 units respectively. + classifier = tf.estimator.Estimator( + model_fn=my_model, + params={ + 'feature_columns': my_feature_columns, + # Two hidden layers of 10 nodes each. + 'hidden_units': [10, 10], + # The model must choose between 3 classes. + 'n_classes': 3, + }) +``` +Here the `params` dictionary serves the same purpose as the key-word +arguments of `DNNClassifier`; that is, the `params` dictionary lets you +configure your Estimator without modifying the code in the `model_fn`. + +The rest of the code to train, evaluate, and generate predictions using our +Estimator is the same as for the pre-made `DNNClassifier`. For example, the +following line will train the model: + +```python + # Train the Model. + classifier.train( + input_fn=lambda:train_input_fn(train_x, train_y, args.batch_size), + steps=args.train_steps) +``` + +## TensorBoard + +You can view training results for your custom Estimator in TensorBoard. To see +this reporting, start TensorBoard from your command line as follows: + +```bsh +# Replace PATH with the actual path passed as model_dir +tensorboard --logdir=PATH +``` + +Then, open TensorBoard by browsing to: [http://localhost:6006](http://localhost:6006) + +All the pre-made Estimators automatically log a lot of information to +TensorBoard. With custom Estimators, however, TensorBoard only provides one +default log (a graph of the loss) plus the information you explicitly tell +TensorBoard to log. For the custom Estimator you just created, TensorBoard +generates the following: + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="height:260px" + alt="Accuracy, steps/second, and loss 'scalar' graphs from tensorboard" + src="../images/custom_estimators/tensorboard.png"> +</div> +<div style="text-align: center"> +TensorBoard displays three graphs. +</div> + +In brief, here's what the three graphs tell you: + +* global_step/sec: A performance indicator showing how many batches (gradient + updates) we processed per second as the model trains. + +* loss: The loss reported. + +* accuracy: The accuracy is recorded by the following two lines: + + * `eval_metric_ops={'my_accuracy': accuracy})`, during evaluation. + * `tf.summary.scalar('accuracy', accuracy[1])`, during training. + +These tensorboard graphs are one of the main reasons it's important to pass a +`global_step` to your optimizer's `minimize` method. The model can't record +the x-coordinate for these graphs without it. + +Note the following in the `my_accuracy` and `loss` graphs: + +* The orange line represents training. +* The blue dot represents evaluation. + +During training, summaries (the orange line) are recorded periodically as +batches are processed, which is why it becomes a graph spanning x-axis range. + +By contrast, evaluation produces only a single point on the graph for each call +to `evaluate`. This point contains the average over the entire evaluation call. +This has no width on the graph as it is evaluated entirely from the model state +at a particular training step (from a single checkpoint). + +As suggested in the following figure, you may see and also selectively +disable/enable the reporting using the controls on the left side. + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="margin:auto;display:block;" + alt="Check-boxes allowing the user to select which runs are shown." + src="../images/custom_estimators/select_run.jpg"> +</div> +<div style="text-align: center"> +Enable or disable reporting. +</div> + + +## Summary + +Although pre-made Estimators can be an effective way to quickly create new +models, you will often need the additional flexibility that custom Estimators +provide. Fortunately, pre-made and custom Estimators follow the same +programming model. The only practical difference is that you must write a model +function for custom Estimators; everything else is the same. + +For more details, be sure to check out: + +* The +[official TensorFlow implementation of MNIST](https://github.com/tensorflow/models/tree/master/official/mnist), +which uses a custom estimator. + +* The TensorFlow +[official models repository](https://github.com/tensorflow/models/tree/master/official), +which contains more curated examples using custom estimators. + +* This [TensorBoard video](https://youtu.be/eBbEDRsCmv4), which introduces +TensorBoard. + + diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/get_started/feature_columns.md new file mode 100644 index 0000000000..f9537927b7 --- /dev/null +++ b/tensorflow/docs_src/get_started/feature_columns.md @@ -0,0 +1,570 @@ +# Feature Columns + +This document details feature columns. Think of **feature columns** as the +intermediaries between raw data and Estimators. Feature columns are very rich, +enabling you to transform a diverse range of raw data into formats that +Estimators can use, allowing easy experimentation. + +In @{$get_started/estimator$Premade Estimators}, we used the premade Estimator, +@{tf.estimator.DNNClassifier$`DNNClassifier`} to train a model to predict +different types of Iris flowers from four input features. That example created +only numerical feature columns (of type @{tf.feature_column.numeric_column}). +Although numerical feature columns model the lengths of petals and sepals +effectively, real world data sets contain all kinds of features, many of which +are non-numerical. + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/feature_cloud.jpg"> +</div> +<div style="text-align: center"> +Some real-world features (such as, longitude) are numerical, but many are not. +</div> + +## Input to a Deep Neural Network + +What kind of data can a deep neural network operate on? The answer +is, of course, numbers (for example, `tf.float32`). After all, every neuron in +a neural network performs multiplication and addition operations on weights and +input data. Real-life input data, however, often contains non-numerical +(categorical) data. For example, consider a `product_class` feature that can +contain the following three non-numerical values: + +* `kitchenware` +* `electronics` +* `sports` + +ML models generally represent categorical values as simple vectors in which a +1 represents the presence of a value and a 0 represents the absence of a value. +For example, when `product_class` is set to `sports`, an ML model would usually +represent `product_class` as `[0, 0, 1]`, meaning: + +* `0`: `kitchenware` is absent +* `0`: `electronics` is absent +* `1`: `sports` is present + +So, although raw data can be numerical or categorical, an ML model represents +all features as numbers. + +## Feature Columns + +As the following figure suggests, you specify the input to a model through the +`feature_columns` argument of an Estimator (`DNNClassifier` for Iris). +Feature Columns bridge input data (as returned by `input_fn`) with your model. + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/inputs_to_model_bridge.jpg"> +</div> +<div style="text-align: center"> +Feature columns bridge raw data with the data your model needs. +</div> + +To create feature columns, call functions from the +@{tf.feature_column} module. This document explains nine of the functions in +that module. As the following figure shows, all nine functions return either a +Categorical-Column or a Dense-Column object, except `bucketized_column`, which +inherits from both classes: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/some_constructors.jpg"> +</div> +<div style="text-align: center"> +Feature column methods fall into two main categories and one hybrid category. +</div> + +Let's look at these functions in more detail. + +### Numeric column + +The Iris classifier calls the @{tf.feature_column.numeric_column} function for +all input features: + + * `SepalLength` + * `SepalWidth` + * `PetalLength` + * `PetalWidth` + +Although `tf.numeric_column` provides optional arguments, calling +`tf.numeric_column` without any arguments, as follows, is a fine way to specify +a numerical value with the default data type (`tf.float32`) as input to your +model: + +```python +# Defaults to a tf.float32 scalar. +numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength") +``` + +To specify a non-default numerical data type, use the `dtype` argument. For +example: + +``` python +# Represent a tf.float64 scalar. +numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength", + dtype=tf.float64) +``` + +By default, a numeric column creates a single value (scalar). Use the shape +argument to specify another shape. For example: + +<!--TODO(markdaoust) link to full example--> +```python +# Represent a 10-element vector in which each cell contains a tf.float32. +vector_feature_column = tf.feature_column.numeric_column(key="Bowling", + shape=10) + +# Represent a 10x5 matrix in which each cell contains a tf.float32. +matrix_feature_column = tf.feature_column.numeric_column(key="MyMatrix", + shape=[10,5]) +``` +### Bucketized column + +Often, you don't want to feed a number directly into the model, but instead +split its value into different categories based on numerical ranges. To do so, +create a @{tf.feature_column.bucketized_column$bucketized column}. For +example, consider raw data that represents the year a house was built. Instead +of representing that year as a scalar numeric column, we could split the year +into the following four buckets: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/bucketized_column.jpg"> +</div> +<div style="text-align: center"> +Dividing year data into four buckets. +</div> + +The model will represent the buckets as follows: + +|Date Range |Represented as... | +|:----------|:-----------------| +|< 1960 | [1, 0, 0, 0] | +|>= 1960 but < 1980 | [0, 1, 0, 0] | +|>= 1980 but < 2000 | [0, 0, 1, 0] | +|> 2000 | [0, 0, 0, 1] | + +Why would you want to split a number—a perfectly valid input to your +model—into a categorical value? Well, notice that the categorization splits a +single input number into a four-element vector. Therefore, the model now can +learn _four individual weights_ rather than just one; four weights creates a +richer model than one weight. More importantly, bucketizing enables the model +to clearly distinguish between different year categories since only one of the +elements is set (1) and the other three elements are cleared (0). When we just +use a single number (a year) as input, the model can only learn a linear +relationship. So, bucketing provides the model with additional flexibility that +the model can use to learn. + +The following code demonstrates how to create a bucketized feature: + +<!--TODO(markdaoust) link to full example - housing price grid?--> +```python +# First, convert the raw input to a numeric column. +numeric_feature_column = tf.feature_column.numeric_column("Year") + +# Then, bucketize the numeric column on the years 1960, 1980, and 2000. +bucketized_feature_column = tf.feature_column.bucketized_column( + source_column = numeric_feature_column, + boundaries = [1960, 1980, 2000]) +``` +Note that specifying a _three_-element boundaries vector creates a +_four_-element bucketized vector. + + +### Categorical identity column + +**Categorical identity columns** can be seen as a special case of bucketized +columns. In traditional bucketized columns, each bucket represents a range of +values (for example, from 1960 to 1979). In a categorical identity column, each +bucket represents a single, unique integer. For example, let's say you want to +represent the integer range `[0, 4)`. That is, you want to represent the +integers 0, 1, 2, or 3. In this case, the categorical identity mapping looks +like this: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg"> +</div> +<div style="text-align: center"> +A categorical identity column mapping. Note that this is a one-hot +encoding, not a binary numerical encoding. +</div> + +As with bucketized columns, a model can learn a separate weight for each class +in a categorical identity column. For example, instead of using a string to +represent the `product_class`, let's represent each class with a unique integer +value. That is: + +* `0="kitchenware"` +* `1="electronics"` +* `2="sport"` + +Call @{tf.feature_column.categorical_column_with_identity} to implement a +categorical identity column. For example: + +``` python +# Create categorical output for an integer feature named "my_feature_b", +# The values of my_feature_b must be >= 0 and < num_buckets +identity_feature_column = tf.feature_column.categorical_column_with_identity( + key='my_feature_b', + num_buckets=4) # Values [0, 4) + +# In order for the preceding call to work, the input_fn() must return +# a dictionary containing 'my_feature_b' as a key. Furthermore, the values +# assigned to 'my_feature_b' must belong to the set [0, 4). +def input_fn(): + ... + return ({ 'my_feature_a':[7, 9, 5, 2], 'my_feature_b':[3, 1, 2, 2] }, + [Label_values]) +``` + +### Categorical vocabulary column + +We cannot input strings directly to a model. Instead, we must first map strings +to numeric or categorical values. Categorical vocabulary columns provide a good +way to represent strings as a one-hot vector. For example: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/categorical_column_with_vocabulary.jpg"> +</div> +<div style="text-align: center"> +Mapping string values to vocabulary columns. +</div> + +As you can see, categorical vocabulary columns are kind of an enum version of +categorical identity columns. TensorFlow provides two different functions to +create categorical vocabulary columns: + +* @{tf.feature_column.categorical_column_with_vocabulary_list} +* @{tf.feature_column.categorical_column_with_vocabulary_file} + +`categorical_column_with_vocabulary_list` maps each string to an integer based +on an explicit vocabulary list. For example: + +```python +# Given input "feature_name_from_input_fn" which is a string, +# create a categorical feature by mapping the input to one of +# the elements in the vocabulary list. +vocabulary_feature_column = + tf.feature_column.categorical_column_with_vocabulary_list( + key="a feature returned by input_fn()", + vocabulary_list=["kitchenware", "electronics", "sports"]) +``` + +The preceding function is pretty straightforward, but it has a significant +drawback. Namely, there's way too much typing when the vocabulary list is long. +For these cases, call +`tf.feature_column.categorical_column_with_vocabulary_file` instead, which lets +you place the vocabulary words in a separate file. For example: + +```python + +# Given input "feature_name_from_input_fn" which is a string, +# create a categorical feature to our model by mapping the input to one of +# the elements in the vocabulary file +vocabulary_feature_column = + tf.feature_column.categorical_column_with_vocabulary_file( + key="a feature returned by input_fn()", + vocabulary_file="product_class.txt", + vocabulary_size=3) +``` + +`product_class.txt` should contain one line for each vocabulary element. In our +case: + +```None +kitchenware +electronics +sports +``` + +### Hashed Column + +So far, we've worked with a naively small number of categories. For example, +our product_class example has only 3 categories. Often though, the number of +categories can be so big that it's not possible to have individual categories +for each vocabulary word or integer because that would consume too much memory. +For these cases, we can instead turn the question around and ask, "How many +categories am I willing to have for my input?" In fact, the +@{tf.feature_column.categorical_column_with_hash_bucket} function enables you +to specify the number of categories. For this type of feature column the model +calculates a hash value of the input, then puts it into one of +the `hash_bucket_size` categories using the modulo operator, as in the following +pseudocode: + +```python +# pseudocode +feature_id = hash(raw_feature) % hash_buckets_size +``` + +The code to create the `feature_column` might look something like this: + +``` python +hashed_feature_column = + tf.feature_column.categorical_column_with_hash_bucket( + key = "some_feature", + hash_buckets_size = 100) # The number of categories +``` +At this point, you might rightfully think: "This is crazy!" After all, we are +forcing the different input values to a smaller set of categories. This means +that two probably unrelated inputs will be mapped to the same +category, and consequently mean the same thing to the neural network. The +following figure illustrates this dilemma, showing that kitchenware and sports +both get assigned to category (hash bucket) 12: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/hashed_column.jpg"> +</div> +<div style="text-align: center"> +Representing data with hash buckets. +</div> + +As with many counterintuitive phenomena in machine learning, it turns out that +hashing often works well in practice. That's because hash categories provide +the model with some separation. The model can use additional features to further +separate kitchenware from sports. + +### Crossed column + +Combining features into a single feature, better known as +[feature crosses](https://developers.google.com/machine-learning/glossary/#feature_cross), +enables the model to learn separate weights for each combination of +features. + +More concretely, suppose we want our model to calculate real estate prices in +Atlanta, GA. Real-estate prices within this city vary greatly depending on +location. Representing latitude and longitude as separate features isn't very +useful in identifying real-estate location dependencies; however, crossing +latitude and longitude into a single feature can pinpoint locations. Suppose we +represent Atlanta as a grid of 100x100 rectangular sections, identifying each +of the 10,000 sections by a feature cross of latitude and longitude. This +feature cross enables the model to train on pricing conditions related to each +individual section, which is a much stronger signal than latitude and longitude +alone. + +The following figure shows our plan, with the latitude & longitude values for +the corners of the city in red text: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/Atlanta.jpg"> +</div> +<div style="text-align: center"> +Map of Atlanta. Imagine this map divided into 10,000 sections of +equal size. +</div> + +For the solution, we used a combination of the `bucketized_column` we looked at +earlier, with the @{tf.feature_column.crossed_column} function. + +<!--TODO(markdaoust) link to full example--> + +``` python +def make_dataset(latitude, longitude, labels): + assert latitude.shape == longitude.shape == labels.shape + + features = {'latitude': latitude.flatten(), + 'longitude': longitude.flatten()} + labels=labels.flatten() + + return tf.data.Dataset.from_tensor_slices((features, labels)) + + +# Bucketize the latitude and longitude usig the `edges` +latitude_bucket_fc = tf.feature_column.bucketized_column( + tf.feature_column.numeric_column('latitude'), + list(atlanta.latitude.edges)) + +longitude_bucket_fc = tf.feature_column.bucketized_column( + tf.feature_column.numeric_column('longitude'), + list(atlanta.longitude.edges)) + +# Cross the bucketized columns, using 5000 hash bins. +crossed_lat_lon_fc = tf.feature_column.crossed_column( + [latitude_bucket_fc, longitude_bucket_fc], 5000) + +fc = [ + latitude_bucket_fc, + longitude_bucket_fc, + crossed_lat_lon_fc] + +# Build and train the Estimator. +est = tf.estimator.LinearRegressor(fc, ...) +``` + +You may create a feature cross from either of the following: + +* Feature names; that is, names from the `dict` returned from `input_fn`. +* Any categorical column, except `categorical_column_with_hash_bucket` + (since `crossed_column` hashes the input). + +When the feature columns `latitude_bucket_fc` and `longitude_bucket_fc` are +crossed, TensorFlow will create `(latitude_fc, longitude_fc)` pairs for each +example. This would produce a full grid of possibilities as follows: + +``` None + (0,0), (0,1)... (0,99) + (1,0), (1,1)... (1,99) + ... ... ... +(99,0), (99,1)...(99, 99) +``` + +Except that a full grid would only be tractable for inputs with limited +vocabularies. Instead of building this, potentially huge, table of inputs, +the `crossed_column` only builds the number requested by the `hash_bucket_size` +argument. The feature column assigns an example to a index by running a hash +function on the tuple of inputs, followed by a modulo operation with +`hash_bucket_size`. + +As discussed earlier, performing the +hash and modulo function limits the number of categories, but can cause category +collisions; that is, multiple (latitude, longitude) feature crosses will end +up in the same hash bucket. In practice though, performing feature crosses +still adds significant value to the learning capability of your models. + +Somewhat counterintuitively, when creating feature crosses, you typically still +should include the original (uncrossed) features in your model (as in the +preceding code snippet). The independent latitude and longitude features help the +model distinguish between examples where a hash collision has occured in the +crossed feature. + +## Indicator and embedding columns + +Indicator columns and embedding columns never work on features directly, but +instead take categorical columns as input. + +When using an indicator column, we're telling TensorFlow to do exactly what +we've seen in our categorical product_class example. That is, an +**indicator column** treats each category as an element in a one-hot vector, +where the matching category has value 1 and the rest have 0s: + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg"> +</div> +<div style="text-align: center"> +Representing data in indicator columns. +</div> + +Here's how you create an indicator column by calling +@{tf.feature_column.indicator_column}: + +``` python +categorical_column = ... # Create any type of categorical column. + +# Represent the categorical column as an indicator column. +indicator_column = tf.feature_column.indicator_column(categorical_column) +``` + +Now, suppose instead of having just three possible classes, we have a million. +Or maybe a billion. For a number of reasons, as the number of categories grow +large, it becomes infeasible to train a neural network using indicator columns. + +We can use an embedding column to overcome this limitation. Instead of +representing the data as a one-hot vector of many dimensions, an +**embedding column** represents that data as a lower-dimensional, ordinary +vector in which each cell can contain any number, not just 0 or 1. By +permitting a richer palette of numbers for every cell, an embedding column +contains far fewer cells than an indicator column. + +Let's look at an example comparing indicator and embedding columns. Suppose our +input examples consists of different words from a limited palette of only 81 +words. Further suppose that the data set provides provides the following input +words in 4 separate examples: + +* `"dog"` +* `"spoon"` +* `"scissors"` +* `"guitar"` + +In that case, the following figure illustrates the processing path for +embedding columns or indicator columns. + +<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/feature_columns/embedding_vs_indicator.jpg"> +</div> +<div style="text-align: center"> +An embedding column stores categorical data in a lower-dimensional +vector than an indicator column. (We just placed random numbers into the +embedding vectors; training determines the actual numbers.) +</div> + +When an example is processed, one of the `categorical_column_with...` functions +maps the example string to a numerical categorical value. For example, a +function maps "spoon" to `[32]`. (The 32 comes from our imagination—the actual +values depend on the mapping function.) You may then represent these numerical +categorical values in either of the following two ways: + +* As an indicator column. A function converts each numeric categorical value + into an 81-element vector (because our palette consists of 81 words), placing + a 1 in the index of the categorical value (0, 32, 79, 80) and a 0 in all the + other positions. + +* As an embedding column. A function uses the numerical categorical values + `(0, 32, 79, 80)` as indices to a lookup table. Each slot in that lookup table + contains a 3-element vector. + +How do the values in the embeddings vectors magically get assigned? Actually, +the assignments happen during training. That is, the model learns the best way +to map your input numeric categorical values to the embeddings vector value in +order to solve your problem. Embedding columns increase your model's +capabilities, since an embeddings vector learns new relationships between +categories from the training data. + +Why is the embedding vector size 3 in our example? Well, the following "formula" +provides a general rule of thumb about the number of embedding dimensions: + +```python +embedding_dimensions = number_of_categories**0.25 +``` + +That is, the embedding vector dimension should be the 4th root of the number of +categories. Since our vocabulary size in this example is 81, the recommended +number of dimensions is 3: + +``` python +3 = 81**0.25 +``` +Note that this is just a general guideline; you can set the number of embedding +dimensions as you please. + +Call @{tf.feature_column.embedding_column} to create an `embedding_column` as +suggested by the following snippet: + +``` python +categorical_column = ... # Create any categorical column + +# Represent the categorical column as an embedding column. +# This means creating a one-hot vector with one element for each category. +embedding_column = tf.feature_column.embedding_column( + categorical_column=categorical_column, + dimension=dimension_of_embedding_vector) +``` + +@{$programmers_guide/embedding$Embeddings} is a significant topic within machine +learning. This information was just to get you started using them as feature +columns. + +## Passing feature columns to Estimators + +As the following list indicates, not all Estimators permit all types of +`feature_columns` argument(s): + +* @{tf.estimator.LinearClassifier$`LinearClassifier`} and + @{tf.estimator.LinearRegressor$`LinearRegressor`}: Accept all types of + feature column. +* @{tf.estimator.DNNClassifier$`DNNClassifier`} and + @{tf.estimator.DNNRegressor$`DNNRegressor`}: Only accept dense columns. Other + column types must be wrapped in either an `indicator_column` or + `embedding_column`. +* @{tf.estimator.DNNLinearCombinedClassifier$`DNNLinearCombinedClassifier`} and + @{tf.estimator.DNNLinearCombinedRegressor$`DNNLinearCombinedRegressor`}: + * The `linear_feature_columns` argument accepts any feature column type. + * The `dnn_feature_columns` argument only accepts dense columns. + +## Other Sources + +For more examples on feature columns, view the following: + +* The @{$wide_and_deep$Wide & Deep Tutorial} +* [Examples](https://github.com/tensorflow/models/tree/master/samples/cookbook/regression) + of DNNs and linear models that use feature columns. + +To learn more about embeddings, see the following: + +* [Deep Learning, NLP, and representations](http://colah.github.io/posts/2014-07-NLP-RNNs-Representations/) + (Chris Olah's blog) +* The TensorFlow [Embedding Projector](http://projector.tensorflow.org) diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 217f542caa..a49973d550 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -511,6 +511,87 @@ contracted dimensions of `lhs` and `rhs` must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or matrix/matrix multiplications. +## DotGeneral + +See also +[`ComputationBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). + +<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b> + +| Arguments | Type | Semantics +| --------- | ----------------------- | --------------- +| `lhs` | `ComputationDataHandle` | array of type T +| `rhs` | `ComputationDataHandle` | array of type T +| `dimension_numbers` | `DotDimensionNumbers` | array of type T + +As Dot, but allows contracting and batch dimension numbers to be specified for +both the 'lhs' and 'rhs'. + +| DotDimensionNumbers Fields | Type | Semantics +| --------- | ----------------------- | --------------- +| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers | +| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers | +| 'lhs_batch_dimensions' | repeated int64 | 'lhs' batch dimension numbers | +| 'rhs_batch_dimensions' | repeated int64 | 'rhs' batch dimension numbers | + +DotGeneral performs the sum of products over contracting dimensions specified +in 'dimension_numbers'. + +Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need +to be the same, but must be listed in the same order in both +'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. + +Example with contracting dimension numbers: + +``` +lhs = { {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0} } + +rhs = { {1.0, 1.0, 1.0}, + {2.0, 2.0, 2.0} } + +DotDimensionNumbers dnums; +dnums.add_lhs_contracting_dimensions(1); +dnums.add_rhs_contracting_dimensions(1); + +DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, + {15.0, 30.0} } +``` + +Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same +dimension number, must be listed in the same order in both arrays, and must +have the same dimension sizes. + +Example with batch dimension numbers (batch size 2, 2x2 matrices): + +``` +lhs = { { {1.0, 2.0}, + {3.0, 4.0} }, + { {5.0, 6.0}, + {7.0, 8.0} } } + +rhs = { { {1.0, 0.0}, + {0.0, 1.0} }, + { {1.0, 0.0}, + {0.0, 1.0} } } + +DotDimensionNumbers dnums; +dnums.add_lhs_contracting_dimensions(2); +dnums.add_rhs_contracting_dimensions(1); +dnums.add_lhs_batch_dimensions(0); +dnums.add_rhs_batch_dimensions(0); + +DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, + {3.0, 4.0} }, + { {5.0, 6.0}, + {7.0, 8.0} } } +``` + +| Input | Output | Semantics | +| ----------------------------------- | ----------------- | ---------------- | +| [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul | +| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul | + ## Element-wise binary arithmetic operations See also diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index 9ced56f0f5..c54b399c3a 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -1,16 +1,16 @@ # Importing Data -The @{tf.data.Dataset$`Dataset`} API enables you to build complex input pipelines from +The `tf.data` API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. The pipeline for a text model might involve extracting symbols from raw text data, converting them to embedding identifiers with a lookup -table, and batching together sequences of different lengths. The `Dataset` API +table, and batching together sequences of different lengths. The `tf.data` API makes it easy to deal with large amounts of data, different data formats, and complicated transformations. -The `Dataset` API introduces two new abstractions to TensorFlow: +The `tf.data` API introduces two new abstractions to TensorFlow: * A `tf.data.Dataset` represents a sequence of elements, in which each element contains one or more `Tensor` objects. For example, in an image @@ -121,7 +121,7 @@ dataset3 = dataset3.filter(lambda x, (y, z): ...) ### Creating an iterator Once you have built a `Dataset` to represent your input data, the next step is to -create an `Iterator` to access elements from that dataset. The `Dataset` API +create an `Iterator` to access elements from that dataset. The `tf.data` API currently supports the following iterators, in increasing level of sophistication: @@ -379,7 +379,7 @@ sess.run(iterator.initializer, feed_dict={features_placeholder: features, ### Consuming TFRecord data -The `Dataset` API supports a variety of file formats so that you can process +The `tf.data` API supports a variety of file formats so that you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The `tf.data.TFRecordDataset` class enables you to @@ -628,7 +628,7 @@ TODO(mrry): Add this section. ### Processing multiple epochs -The `Dataset` API offers two main ways to process multiple epochs of the same +The `tf.data` API offers two main ways to process multiple epochs of the same data. The simplest way to iterate over a dataset in multiple epochs is to use the @@ -693,7 +693,7 @@ dataset = dataset.repeat() The @{tf.train.MonitoredTrainingSession} API simplifies many aspects of running TensorFlow in a distributed setting. `MonitoredTrainingSession` uses the @{tf.errors.OutOfRangeError} to signal that training has completed, so to use it -with the `Dataset` API, we recommend using +with the `tf.data` API, we recommend using `Dataset.make_one_shot_iterator()`. For example: ```python diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 79202a38d7..881a975e60 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -126,6 +126,10 @@ the Android NDK and SDK must be installed on your system. 2. The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). + + * NDK 16, the revision released in November 2017, is **incompatible** with + Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918). + 3. The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android @@ -133,6 +137,10 @@ the Android NDK and SDK must be installed on your system. 23 is required to build the TF Android demo (though it will run on API >= 21 devices). + - The Android Studio SDK Manager's NDK installer will install the latest + revision of the NDK, which is **incompatible** with Bazel. You'll need + to download an older version manually, as (2) suggests. + ##### Edit WORKSPACE The Android entries in diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py index a402eac053..c89e839563 100644 --- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py +++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py @@ -55,12 +55,15 @@ def convert_to(data_set, name): with tf.python_io.TFRecordWriter(filename) as writer: for index in range(num_examples): image_raw = images[index].tostring() - example = tf.train.Example(features=tf.train.Features(feature={ - 'height': _int64_feature(rows), - 'width': _int64_feature(cols), - 'depth': _int64_feature(depth), - 'label': _int64_feature(int(labels[index])), - 'image_raw': _bytes_feature(image_raw)})) + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'height': _int64_feature(rows), + 'width': _int64_feature(cols), + 'depth': _int64_feature(depth), + 'label': _int64_feature(int(labels[index])), + 'image_raw': _bytes_feature(image_raw) + })) writer.write(example.SerializeToString()) diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py index f46d5e59b4..f5bf04305a 100644 --- a/tensorflow/examples/speech_commands/train.py +++ b/tensorflow/examples/speech_commands/train.py @@ -156,7 +156,8 @@ def main(_): predicted_indices = tf.argmax(logits, 1) expected_indices = tf.argmax(ground_truth_input, 1) correct_prediction = tf.equal(predicted_indices, expected_indices) - confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices, num_classes=label_count) + confusion_matrix = tf.confusion_matrix( + expected_indices, predicted_indices, num_classes=label_count) evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.summary.scalar('accuracy', evaluation_step) diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 46c600eab1..f200a8e00a 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -20,6 +20,24 @@ package tensorflow // // #include <stdlib.h> // #include <string.h> +// +// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc, +// const char* attr_name, +// const int64_t* flat_dims, +// const int* num_dims, +// int num_shapes) { +// const int64_t** dims = +// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes); +// for (int i = 0; i < num_shapes; i++) { +// dims[i] = flat_dims; +// if (num_dims[i] > 0) { +// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0. +// flat_dims += num_dims[i]; +// } +// } +// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes); +// free(dims); +// } import "C" import ( @@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu return fmt.Errorf("bad value for attribute %q: %v", name, err) } case Shape: - ndims, dims := cshape(value) + ndims := C.int(value.NumDimensions()) var dimsp *C.int64_t if ndims > 0 { + dims := make([]C.int64_t, ndims) + for i, d := range value.dims { + dims[i] = C.int64_t(d) + } dimsp = &dims[0] } C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) case []Shape: - ndims := make([]C.int, len(value)) - dims := make([][]C.int64_t, len(value)) - dimsp := make([]*C.int64_t, len(value)) - for i, s := range value { - ndims[i], dims[i] = cshape(s) - if ndims[i] > 0 { - dimsp[i] = &dims[i][0] - } - } - if len(value) > 0 { - C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) - } else { + if len(value) == 0 { C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) + } else { + var flatDims []C.int64_t + ndims := make([]C.int, len(value)) + for i, s := range value { + nd := s.NumDimensions() + ndims[i] = C.int(nd) + for _, d := range s.dims { + flatDims = append(flatDims, C.int64_t(d)) + } + } + var flatDimsp *C.int64_t + if len(flatDims) > 0 { + flatDimsp = &flatDims[0] + } + C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value))) } default: return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil } - -func cshape(s Shape) (C.int, []C.int64_t) { - ndims := C.int(s.NumDimensions()) - if ndims < 0 { - return -1, nil - } - dims := make([]C.int64_t, ndims) - for i, s := range s.dims { - dims[i] = C.int64_t(s) - } - return ndims, dims -} diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go index 2451ba3606..842dee9ffe 100644 --- a/tensorflow/go/op/op_test.go +++ b/tensorflow/go/op/op_test.go @@ -58,3 +58,76 @@ func TestAddOperationFailure(t *testing.T) { _ = resize.Shape() t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created") } + +func TestShapeAttribute(t *testing.T) { + s := NewScope() + x := Placeholder(s.SubScope("x"), tf.Int32, PlaceholderShape(tf.MakeShape(1))) + y := Placeholder(s.SubScope("y"), tf.Int32, PlaceholderShape(tf.Shape{})) + z := Add(s, x, y) + graph, err := s.Finalize() + if err != nil { + t.Fatal(err) + } + sess, err := tf.NewSession(graph, nil) + if err != nil { + t.Fatal(err) + } + + value, err := tf.NewTensor([]int32{7}) + if err != nil { + t.Fatal(err) + } + feeds := map[tf.Output]*tf.Tensor{ + x: value, + y: value, + } + fetched, err := sess.Run(feeds, []tf.Output{z}, nil) + if err != nil { + t.Fatal(err) + } + if got, want := len(fetched), 1; got != want { + t.Fatalf("Fetched %d tensors, expected %d", got, want) + } + if got, want := fetched[0].Value().([]int32), []int32{14}; len(got) != len(want) || len(got) != 1 || got[0] != want[0] { + t.Fatalf("Got %v, want %v", got, want) + } +} + +func TestDataset(t *testing.T) { + var ( + s = NewScope() + + // The use of a non-scalar here is inspired by + // https://github.com/tensorflow/tensorflow/issues/14891 + c = Const(s, []int32{21718, 31415}) + types = []tf.DataType{c.DataType()} + shapes = []tf.Shape{c.Shape()} + dataset = TensorDataset(s, []tf.Output{c}, shapes) + + iterator = Iterator(s, "", "", types, shapes) + next = IteratorGetNext(s, iterator, types, shapes) + init = MakeIterator(s, dataset, iterator) + ) + graph, err := s.Finalize() + if err != nil { + t.Fatal(err) + } + sess, err := tf.NewSession(graph, nil) + if err != nil { + t.Fatal(err) + } + if _, err := sess.Run(nil, nil, []*tf.Operation{init}); err != nil { + t.Fatal(err) + } + results, err := sess.Run(nil, next, nil) + if err != nil { + t.Fatal(err) + } + got := results[0].Value().([]int32) + if len(got) != 2 || got[0] != 21718 || got[1] != 31415 { + t.Errorf("Got %v, want {21718, 31415}", got) + } + if _, err := sess.Run(nil, next, nil); err == nil { + t.Errorf("Expected sess.Run() to fail since the iterator should have reached the end of the dataset") + } +} diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index cd6f4bc1f0..2d25c04dc9 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -270,7 +270,7 @@ func typeOf(dt DataType, shape []int64) reflect.Type { } } if ret == nil { - panic(bug("DataType %v is not supported", dt)) + panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt)) } for range shape { ret = reflect.SliceOf(ret) diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java index beb3635585..a24150484e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java @@ -352,7 +352,8 @@ public final class OperationBuilder { private static native void setAttrShape(long handle, String name, long[] shape, int numDims); - private static native void setAttrShapeList(long handle, String name, long[] shapes, int[] numDims); + private static native void setAttrShapeList( + long handle, String name, long[] shapes, int[] numDims); private static native void setAttrStringList(long handle, String name, Object[] value); } diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/operation_builder_jni.cc index 71a451ad13..55d214a7c4 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.cc +++ b/tensorflow/java/src/main/native/operation_builder_jni.cc @@ -275,15 +275,15 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( if (num_dims_length > 0) { const int shapes_length = env->GetArrayLength(shapes); cshapes.reset(new int64_t[shapes_length]); - cdims.reset(new int64_t* [num_dims_length]); + cdims.reset(new int64_t*[num_dims_length]); cnum_dims.reset(new int[num_dims_length]); jlong* shapes_elems = - (jlong*) env->GetPrimitiveArrayCritical(shapes, nullptr); + static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr)); std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3); env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT); int64_t* cshapes_ptr = cshapes.get(); jint* num_dims_elems = - (jint*) env->GetPrimitiveArrayCritical(num_dims, nullptr); + static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr)); for (int i = 0; i < num_dims_length; ++i) { cnum_dims[i] = static_cast<int>(num_dims_elems[i]); cdims[i] = cshapes_ptr; diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java index 2430816725..0a4a8cf4e3 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java @@ -151,10 +151,10 @@ public class OperationBuilderTest { @Test public void setAttrShapeList() { // Those shapes match tensors ones, so no exception is thrown - testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2) }); + testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}); try { // Those shapes do not match tensors ones, exception is thrown - testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2, 2) }); + testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)}); fail("Shapes are incompatible and an exception was expected"); } catch (IllegalArgumentException e) { // expected @@ -189,20 +189,23 @@ public class OperationBuilderTest { } private static void testSetAttrShapeList(Shape[] shapes) { - try (Graph g = new Graph(); Session s = new Session(g)) { - int[][] matrix = new int[][] { { 0, 0 }, { 0, 0 } }; - Output<?> queue = g.opBuilder("FIFOQueue", "queue") - .setAttr("component_types", new DataType[] { DataType.INT32, DataType.INT32 }) - .setAttr("shapes", shapes) - .build() - .output(0); + try (Graph g = new Graph(); + Session s = new Session(g)) { + int[][] matrix = new int[][] {{0, 0}, {0, 0}}; + Output<?> queue = + g.opBuilder("FIFOQueue", "queue") + .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) + .setAttr("shapes", shapes) + .build() + .output(0); assertTrue(hasNode(g, "queue")); Output<Integer> c1 = TestUtil.constant(g, "const1", matrix); - Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] { matrix, matrix }); - Operation enqueue = g.opBuilder("QueueEnqueue", "enqueue") - .addInput(queue) - .addInputList(new Output<?>[] { c1, c2 }) - .build(); + Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix}); + Operation enqueue = + g.opBuilder("QueueEnqueue", "enqueue") + .addInput(queue) + .addInputList(new Output<?>[] {c1, c2}) + .build(); assertTrue(hasNode(g, "enqueue")); s.runner().addTarget(enqueue).run(); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 23ad9bfa56..12d81c4383 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -268,6 +268,7 @@ cc_library( deps = [ ":ndarray_tensor_bridge", ":numpy_lib", + ":py_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -309,6 +310,7 @@ cc_library( hdrs = ["lib/core/py_seq_tensor.h"], deps = [ ":numpy_lib", + ":py_util", ":safe_ptr", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -317,6 +319,17 @@ cc_library( ) cc_library( + name = "py_util", + srcs = ["lib/core/py_util.cc"], + hdrs = ["lib/core/py_util.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:script_ops_op_lib", + "//util/python:python_headers", + ], +) + +cc_library( name = "py_record_reader_lib", srcs = ["lib/io/py_record_reader.cc"], hdrs = ["lib/io/py_record_reader.h"], diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index f4b0271195..e4545d287b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -28,6 +28,8 @@ import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -1742,5 +1744,136 @@ class SessionTest(test_util.TensorFlowTestCase): self.runTestAddFunctionToSession(server.target) +class GraphMutationTest(test_util.TensorFlowTestCase): + + def testUpdateInputAfterRunning(self): + with ops.Graph().as_default() as g: + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess: + self.assertAllEqual(3.0, sess.run(c)) + c.op._update_input(1, a) # pylint: disable=protected-access + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by updating input tensor after it was run'): + sess.run(c) + + # Check that running the graph with a new session is fine + with session.Session(graph=g) as sess2: + self.assertAllEqual(2.0, sess2.run(c)) + + def testSetDeviceAfterRunning(self): + with ops.Graph().as_default() as g: + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess: + self.assertAllEqual(3.0, sess.run(c)) + c.op._set_device('/cpu:0') # pylint: disable=protected-access + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by setting device after it was run'): + sess.run(c) + + def testSetAttrAfterRunning(self): + with ops.Graph().as_default() as g: + a = constant_op.constant(1.0, dtype=dtypes.float32) + b = math_ops.cast(a, dtypes.float64) + + with session.Session(graph=g) as sess: + self.assertAllEqual(1.0, sess.run(b)) + b.op._set_attr('DstT', + attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)) + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'Cast.*was changed by setting attribute after it was run'): + sess.run(b) + + def testRunModifyRun(self): + with ops.Graph().as_default() as g: + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess: + self.assertAllEqual(3.0, sess.run(c)) + + d = b + c + d.op._update_input(0, a) # pylint: disable=protected-access + self.assertAllEqual(3.0, sess.run(c)) + self.assertAllEqual(4.0, sess.run(d)) + + def testRunModifyRunTwoSessions(self): + with ops.Graph().as_default() as g: + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess1: + with session.Session(graph=g) as sess2: + self.assertAllEqual(3.0, sess1.run(c)) + self.assertAllEqual(3.0, sess2.run(c)) + + d = b + c + d.op._update_input(0, a) # pylint: disable=protected-access + self.assertAllEqual(3.0, sess2.run(c)) + self.assertAllEqual(4.0, sess2.run(d)) + + d.op._update_input(0, b) # pylint: disable=protected-access + self.assertAllEqual(3.0, sess1.run(c)) + self.assertAllEqual(5.0, sess1.run(d)) + + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by updating input tensor after it was run'): + sess2.run(c) + + def testTwoSessionsOneRunBeforeModification(self): + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess1: + with session.Session(graph=g) as sess2: + sess1.run(c) + + c.op._set_device('/cpu:0') # pylint: disable=protected-access + + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by setting device after it was run'): + sess1.run(c) + + # sess2 was not run before modification + self.assertAllEqual(3.0, sess2.run(c)) + + def testTwoSessionsBothRunBeforeModification(self): + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + + with session.Session(graph=g) as sess1: + with session.Session(graph=g) as sess2: + sess1.run(c) + sess2.run(c) + + c.op._set_device('/cpu:0') # pylint: disable=protected-access + + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by setting device after it was run'): + sess1.run(c) + + with self.assertRaisesRegexp( + errors.FailedPreconditionError, + 'add.*was changed by setting device after it was run'): + sess2.run(c) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 5fa1a7e8fc..d471a39b69 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -532,6 +532,49 @@ def TF_Reset(target, containers=None, config=None): %unignore TF_GraphGetTensorShapeHelper; %ignore TF_GraphGetTensorShape; +// We use TF_GraphSetTensorShape_wrapper instead of +// TF_GraphSetTensorShape +%ignore TF_GraphSetTensorShape; +%unignore tensorflow; +%unignore TF_GraphSetTensorShape_wrapper; + +// $input is a Python list of ints to a vector<int> for TF_GraphSetTensorShape_wrapper +%typemap(in) (const std::vector<int64_t>& dims) + (std::vector<int64_t> dims_local){ + if ($input != Py_None) { + if (!PyList_Check($input)) { + SWIG_exception_fail(SWIG_TypeError, tensorflow::strings::Printf( + "$symname: expected list but got %s ", Py_TYPE($input)->tp_name).c_str()); + } + size_t size = PyList_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem($input, i); + dims_local.push_back(PyInt_AsLong(item)); + } + $1 = &dims_local; + } else { + $1 = nullptr; + } +} + +// We use TF_GraphGetTensorShape_wrapper instead of +// TF_GraphGetTensorShape +%ignore TF_GraphGetTensorShape; +%unignore tensorflow; +%unignore TF_GraphGetTensorShape_wrapper; + +// Build a Python list of ints and return it. +%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + } +} + %include "tensorflow/python/client/tf_session_helper.h" %unignoreall diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index ad982e5dd8..e4bf09a0ca 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -407,4 +407,23 @@ TF_Function* TF_GraphToFunction_wrapper( opts, description, out_status); } +void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, + const std::vector<int64_t>& dims, + bool unknown_shape, TF_Status* status) { + if (unknown_shape) { + TF_GraphSetTensorShape(graph, output, nullptr, -1, status); + return; + } + TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status); +} + +std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, + TF_Output output, + int num_dims, + TF_Status* status) { + std::vector<int64_t> dims(num_dims); + TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status); + return dims; +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 6ed08d3a58..bb7171db31 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -168,6 +168,20 @@ TF_Function* TF_GraphToFunction_wrapper( const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, const NameVector& output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* out_status); + +// Set the shape of output. If unknown is true, `num_dims` must be set to +// -1 and `dims` is set to nullptr. +void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, + const std::vector<int64_t>& dims, + bool unknown_shape, TF_Status* status); + +// Return the shape of output. `num_dims` should be the output of +// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called. +std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, + TF_Output output, + int num_dims, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 05acfe4de7..695d3ef790 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -21,6 +21,7 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index dbe29c087a..927c6d5c02 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.util import deprecation class Dataset(object): @@ -219,6 +220,7 @@ class Dataset(object): return TensorSliceDataset(tensors) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") def from_sparse_tensor_slices(sparse_tensor): """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. @@ -1232,13 +1234,40 @@ class ShuffleDataset(Dataset): input_dataset, buffer_size, seed=None, - reshuffle_each_iteration=None): - """See `Dataset.shuffle()` for details.""" + reshuffle_each_iteration=None, + seed2=None): + """Randomly shuffles the elements of this dataset. + + Args: + input_dataset: The input dataset. + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + number of elements from this dataset from which the new + dataset will sample. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + reshuffle_each_iteration: (Optional.) A boolean, which if true indicates + that the dataset should be pseudorandomly reshuffled each time it is + iterated over. (Defaults to `True`.) + seed2: (Optional.) A `tf.int64` scalar `tf.Tensor` used to avoid seed + collision. Users should generally not need to specify this. This is + supposed to be used when both the seeds for the Dataset op need to be + manually specified. If not None, seed must also be non-None. + + Returns: + A `Dataset`. + + Raises: + ValueError: if invalid arguments are provided. + """ super(ShuffleDataset, self).__init__() self._input_dataset = input_dataset self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - seed, seed2 = random_seed.get_seed(seed) + if seed2 is None: + seed, seed2 = random_seed.get_seed(seed) + elif seed is None: + raise ValueError("seed must be non-None if seed2 is non-None.") if seed is None: self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") else: diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index bd7ab3d34f..2455395635 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -379,9 +379,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): if check_types and isinstance(shallow_tree, dict): if set(input_tree) != set(shallow_tree): raise ValueError( - "The two structures don't have the same keys. Input " - "structure has keys %s, while shallow structure has keys %s." - % (list(_six.iterkeys(input_tree)), + "The two structures don't have the same keys. Input " + "structure has keys %s, while shallow structure has keys %s." % + (list(_six.iterkeys(input_tree)), list(_six.iterkeys(shallow_tree)))) input_tree = list(_six.iteritems(input_tree)) shallow_tree = list(_six.iteritems(shallow_tree)) diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index 8c84d9d1df..90dd7dfe77 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -271,8 +271,9 @@ class NestTest(test.TestCase): inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} expected_message = ( - "The two structures don't have the same keys. Input " - "structure has keys \['c'\], while shallow structure has keys \['d'\].") + r"The two structures don't have the same keys. Input " + r"structure has keys \['c'\], while shallow structure has " + r"keys \['d'\].") with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 0144f3b1e5..dc1142705a 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -540,7 +540,7 @@ def _ensure_unique_tensor_objects(parameter_positions, args): if i in parameter_positions: tid = ops.tensor_id(t) if tid in s: - args[i] = args[i]._dup() # pylint: disable=protected-access + args[i] = gen_array_ops.identity(args[i]) else: s.add(tid) return args diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 92f4e15c05..415416cfae 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -288,6 +288,21 @@ class Context(object): self._initialize_handle_and_devices() return self._num_gpus + def add_function(self, fn): + """Add a function definition to the context. + + Once added, the function (identified by its name) can be executed like any + other operation. + + Args: + fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). + """ + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextAddFunction( + self._handle, # pylint: disable=protected-access + fn, + status) + def add_function_def(self, fdef): """Add a function definition to the context. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 9bcd9c23c7..cadabb3a24 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -25,15 +25,19 @@ import threading import numpy as np +from tensorflow.core.framework import function_pb2 +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager.graph_only_ops import graph_placeholder +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import dtypes as dtypes_module +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -47,26 +51,41 @@ _scoped_captures = threading.local() _scoped_captures.tensors = None -def make_function_def(graph, operations, inputs, outputs): - """Makes function def where accesses to resources are serialized.""" - last_op_using_resource_tensor = {} - - # TODO(apassos) probably control flow has to be handled delicately here as in - # if a resource is accessed inside a control flow context we need the control - # dependency to point to something outside the context which is guaranteed to - # happen after the access. - # - # TODO(apassos) this should do some form of alias analysis as ops which - # forward the resources such as Identity and Switch can cause serialization to - # fail. - for op in operations: - for t in op.inputs: - if t.dtype == dtypes.resource: - if t.name in last_op_using_resource_tensor: - op._add_control_input(last_op_using_resource_tensor[t.name]) # pylint: disable=protected-access - last_op_using_resource_tensor[t.name] = op - return graph_to_function_def.graph_to_function_def( - graph, operations, inputs, outputs) +def make_function_def(name, graph, operations, inputs, outputs): + """Makes FunctionDef proto and defined function. + + Args: + name: the function name + graph: the graph from which to build the function + operations: the operations in the function body + inputs: tensors to be used as function arguments + outputs: tensors to be returned from the function + + Returns: + fdef: a FunctionDef protocol buffer for the function + fn: a wrapped TF_Function for the function + """ + with errors.raise_exception_on_not_ok_status() as status: + fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( + graph._c_graph, # pylint: disable=protected-access + compat.as_str(name), + False, + [o._c_op for o in operations], # pylint: disable=protected-access + [t._as_tf_output() for t in inputs], # pylint: disable=protected-access + [t._as_tf_output() for t in outputs], # pylint: disable=protected-access + [], + None, + compat.as_str(""), + status) + # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature, + # but also in general it's nice not to depend on it. + with c_api_util.tf_buffer() as buffer_: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) + proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + fdef = function_pb2.FunctionDef() + fdef.ParseFromString(compat.as_bytes(proto_data)) + return fdef, fn @contextlib.contextmanager @@ -85,7 +104,7 @@ def capture_value(tensor_map, value, dtype, name): if captured_value is None: captured_value = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) - if captured_value.dtype == dtypes.resource: + if captured_value.dtype == dtypes_module.resource: captured_value._handle_data = value._handle_data # pylint: disable=protected-access tensor_map[ops.tensor_id(value)] = (value, captured_value) else: @@ -120,11 +139,23 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): class CapturingGraph(ops.Graph): + """Graph used when constructing eager functions.""" def __init__(self, captures): super(CapturingGraph, self).__init__() self._building_function = True self.captures = captures + # Map from resource tensor name to last op (in program order) which uses + # this tensor. Used to enforce that execution order matches program order + # for resource tensors. + self._last_op_using_resource_tensor = {} + + # TODO(apassos) remove once the C API is used by default. + def _use_c_api_hack(self): + return True + + def clear_resource_control_flow_state(self): + self._last_op_using_resource_tensor = {} def create_op( self, @@ -137,12 +168,31 @@ class CapturingGraph(ops.Graph): op_def=None, compute_shapes=True, compute_device=True): + # TODO(apassos) probably control flow has to be handled delicately here as + # in if a resource is accessed inside a control flow context we need the + # control dependency to point to something outside the context which is + # guaranteed to happen after the access. + # + # TODO(apassos) this should do some form of alias analysis as ops which + # forward the resources such as Identity and Switch can cause serialization + # to fail. + resource_inputs = set() + control_inputs = set() for i, inp in enumerate(inputs): if inp.graph is not self: inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name) - return super(CapturingGraph, self).create_op( - op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_shapes, compute_device) + inp = inputs[i] + if inp.dtype == dtypes_module.resource: + if inp.name in self._last_op_using_resource_tensor: + control_inputs.add(self._last_op_using_resource_tensor[inp.name]) + resource_inputs.add(inp.name) + with self.control_dependencies(list(control_inputs)): + op = super(CapturingGraph, self).create_op( + op_type, inputs, dtypes, input_types, name, attrs, op_def, + compute_shapes, compute_device) + for name in resource_inputs: + self._last_op_using_resource_tensor[name] = op + return op # TODO(apassos): it'd be really nice if we could scope this registration. @@ -196,14 +246,20 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) +# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction +# so it doesn't have the definition-generating logic and is just a container for +# an already-defined function. class _DefinedFunction(object): """Mocks the interface of tf _DefinedFunction.""" - def __init__(self, fdef): + def __init__(self, fdef, fn): self.definition = fdef self.name = fdef.signature.name + self.signature = fdef.signature self.grad_func_name = None self.python_grad_func = None + self._c_func = fn + self._grad_func = None def _map_sequence_obj_to_idx(sequence): @@ -239,6 +295,7 @@ class GraphModeFunction(object): input_placeholders, extra_inputs, fdef, + fn, graph, operations, func_outputs, @@ -252,7 +309,7 @@ class GraphModeFunction(object): self._graph = graph self._has_backprop = False self._func_name = fdef.signature.name - self._fdef = _DefinedFunction(fdef) + self._fdef = _DefinedFunction(fdef, fn) self._num_outputs = len(fdef.signature.output_arg) self._ops = operations self._func_outputs = func_outputs @@ -272,38 +329,45 @@ class GraphModeFunction(object): with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: - filtered_outputs = [ - x for x in self._returns if x is not None - ] + filtered_outputs = [x for x in self._returns if x is not None] self._out_grad_placeholders = [ - graph_placeholder(x.dtype, x.shape) for x in filtered_outputs - ] + graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] in_gradients = gradients_impl.gradients( filtered_outputs, self._input_placeholders, grad_ys=self._out_grad_placeholders) - shapes = [x.shape for x in in_gradients if x is not None] + shapes = tuple(x.shape for x in in_gradients if x is not None) captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) - forward_function_def = make_function_def( - self._graph, self._ops, self._input_placeholders, + forward_name = _forward_name(self._func_name) + forward_function_def, forward_fn = make_function_def( + forward_name, self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) - self._forward_fdef = _DefinedFunction(forward_function_def) - _register_with_name(_forward_name(self._func_name), forward_function_def) - backward_outputs = [x for x in in_gradients if x is not None] + self._forward_fdef = _DefinedFunction(forward_function_def, forward_fn) + _register(forward_fn) + backward_outputs = tuple(x for x in in_gradients if x is not None) all_inputs = self._out_grad_placeholders + captures - backward_function_def = make_function_def( - self._graph, [x.op for x in self._out_grad_placeholders - ] + list(sorted(c.known_ops, key=lambda x: x.name)), + # Excluding input ops from the body as we do not intend to execute these + # operations when the function is executed. + all_ignored_ops = frozenset(x.op for x in all_inputs) + # Enforce a deterministic order of operations in the generated graph. This + # means rerunning the function-defining code will always define the same + # function, which is useful if we serialize this etc. + fdef_ops = tuple(x for x in sorted(c.known_ops, key=lambda x: x.name) + if x not in all_ignored_ops) + bname = _backward_name(self._func_name) + backward_function_def, backward_fn = make_function_def( + bname, self._graph, fdef_ops, all_inputs, backward_outputs) - _register_with_name(_backward_name(self._func_name), backward_function_def) + _register(backward_fn) self._backward_function = GraphModeFunction( - all_inputs, [], backward_function_def, self._graph, c.known_ops, - in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes) + all_inputs, [], backward_function_def, backward_fn, self._graph, + c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs), + shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" all_args = args + self._extra_inputs - signature = self._forward_fdef.definition.signature + signature = self._forward_fdef.signature ctx = context.context() if ctx.in_graph_mode(): g = ops.get_default_graph() @@ -314,7 +378,7 @@ class GraphModeFunction(object): return ops.internal_convert_to_tensor(x, ctx=ctx) op = g.create_op( signature.name, [make_tensor(x) for x in all_args], - [dtypes.DType(x.type) for x in signature.output_arg], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), op_def=signature, name="FunctionCall", compute_shapes=False) @@ -350,11 +414,8 @@ class GraphModeFunction(object): if v._trainable: # pylint: disable=protected-access tape.watch_variable(v) - tensor_inputs = [ - x for x in nest.flatten(args) - if isinstance(x, ops.Tensor) - ] - + tensor_inputs = [x for x in nest.flatten(args) + if isinstance(x, ops.Tensor)] if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): if not self._has_backprop: @@ -373,7 +434,7 @@ class GraphModeFunction(object): args = list(tensor_inputs) + self._extra_inputs op = g.create_op( signature.name, [ops.convert_to_tensor(x) for x in args], - [dtypes.DType(x.type) for x in signature.output_arg], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), op_def=signature, name="FunctionCall", compute_shapes=False) @@ -458,29 +519,32 @@ def _defun_internal(name, func, args, kwds): extra_inputs = [] extra_placeholders = [] outputs_list = nest.flatten(func_outputs) - output_shapes = [x.shape for x in outputs_list if x is not None] + output_shapes = tuple(x.shape for x in outputs_list if x is not None) - flat_inputs = [ - x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor) - ] + flat_inputs = [x for x in nest.flatten(func_inputs) + if isinstance(x, ops.Tensor)] all_inputs = flat_inputs + list(extra_placeholders) - + all_ignored_ops = frozenset(x.op for x in all_inputs) func_def_outputs = [x for x in outputs_list if x is not None] - inference_function_def = make_function_def( - tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) + fname = _inference_name(name) + operations = tuple(x for x in tmp_graph.get_operations() + if x not in all_ignored_ops) + inference_function_def, fn = make_function_def( + fname, tmp_graph, operations, all_inputs, func_def_outputs) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? - _register_with_name(f.name, f.definition) - _register_with_name(_inference_name(name), inference_function_def) + _register(f._c_func) # pylint: disable=protected-access + _register(fn) return GraphModeFunction( all_inputs, extra_inputs, inference_function_def, + fn, tmp_graph, - tmp_graph.get_operations(), + operations, func_outputs, _map_sequence_obj_to_idx(func_def_outputs), output_shapes, @@ -506,10 +570,9 @@ def _cache_key(x): return x -def _register_with_name(name, fdef): - """Registers the function `fdef` with the name `name`.""" - fdef.signature.name = name - context.context().add_function_def(fdef) +def _register(fn): + """Registers the function `fn`.""" + context.context().add_function(fn) # TODO(apassos): better error messages for non-hashable arguments. diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 837a75c808..3da100d800 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -296,6 +296,7 @@ def _graph_callable_internal(func, shape_and_dtypes): # Call the function again, now replacing usages of variables with # placeholders. This assumes the variable capturing scope created above # knows about all variables. + tmp_graph.clear_resource_control_flow_state() with variable_captures.capturing_scope(), function.capture_tensors( captures): captured_outputs = func(*func_inputs) @@ -317,7 +318,9 @@ def _graph_callable_internal(func, shape_and_dtypes): placeholder_inputs = flat_inputs+ list(extra_placeholders) func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] - initializer_function_def = function.make_function_def( + initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access + initializer_function_def, initializer_fn = function.make_function_def( + initialization_name, tmp_graph, initializing_operations, placeholder_inputs, @@ -326,13 +329,13 @@ def _graph_callable_internal(func, shape_and_dtypes): # Also, what about the gradient registry of these functions? Those need to be # addressed as well. for f in tmp_graph._functions.values(): # pylint: disable=protected-access - function._register_with_name(f.name, f.definition) # pylint: disable=protected-access - function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access - initializer_function_def) + function._register(f._c_func) # pylint: disable=protected-access + function._register(initializer_fn) # pylint: disable=protected-access initializer_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, initializer_function_def, + initializer_fn, tmp_graph, initializing_operations, func_outputs, @@ -341,18 +344,20 @@ def _graph_callable_internal(func, shape_and_dtypes): capture_func_def_outputs = [ x for x in captured_outlist if isinstance(x, tf_ops.Tensor)] - captured_function_def = function.make_function_def( + captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access + captured_function_def, capturing_fn = function.make_function_def( + captured_function_name, tmp_graph, capturing_operations, placeholder_inputs, capture_func_def_outputs) - function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access - captured_function_def) + function._register(capturing_fn) # pylint: disable=protected-access captured_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, captured_function_def, + capturing_fn, tmp_graph, capturing_operations, captured_outputs, diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py index 548e16a909..b9e6ca2a93 100644 --- a/tensorflow/python/eager/graph_callable_test.py +++ b/tensorflow/python/eager/graph_callable_test.py @@ -152,7 +152,6 @@ class GraphCallableTest(test.TestCase): self.assertAllEqual(5, f(constant_op.constant(2))) def testNestedFunction(self): - # TensorFlow function (which is what would be used in TensorFlow graph # construction). @function.Defun(dtypes.int32, dtypes.int32) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index ce823cb567..b52d71dc6c 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -531,12 +531,9 @@ static PyTypeObject TFE_Py_Tape_Type = { // xcode 7 doesn't define thread_local, so for compatibility we implement our // own. TODO(apassos) remove once we can deprecate xcode 7. #ifndef __APPLE__ -thread_local std::vector<TFE_Py_Tape*>* tape_stack = nullptr; std::vector<TFE_Py_Tape*>* GetTapeStack() { - if (tape_stack == nullptr) { - tape_stack = new std::vector<TFE_Py_Tape*>; - } - return tape_stack; + thread_local std::vector<TFE_Py_Tape*> tape_stack; + return &tape_stack; } #else static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 03f386e9cf..e062e1fbfe 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -215,6 +215,7 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_pip", + "noasan", # test flakily times out in asan mode. "notsan", # b/67510291 ], deps = [ @@ -433,6 +434,7 @@ py_library( "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/data", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index f267f4a54e..63103ef4c1 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -30,6 +30,7 @@ from google.protobuf import message from tensorflow.core.framework import summary_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config @@ -416,7 +417,7 @@ class Estimator(object): with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) self._create_and_assert_global_step(g) - features = self._get_features_from_input_fn( + features, input_hooks = self._get_features_from_input_fn( input_fn, model_fn_lib.ModeKeys.PREDICT) estimator_spec = self._call_model_fn( features, None, model_fn_lib.ModeKeys.PREDICT, self.config) @@ -426,7 +427,7 @@ class Estimator(object): checkpoint_filename_with_path=checkpoint_path, scaffold=estimator_spec.scaffold, config=self._session_config), - hooks=hooks) as mon_sess: + hooks=input_hooks + hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) if not isinstance(predictions, dict): @@ -582,6 +583,11 @@ class Estimator(object): def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) + input_hooks = [] + if isinstance(result, dataset_ops.Dataset): + iterator = result.make_initializable_iterator() + input_hooks.append(_DatasetInitializerHook(iterator)) + result = iterator.get_next() if isinstance(result, (list, tuple)): # Unconditionally drop the label (the second element of result). result = result[0] @@ -590,16 +596,22 @@ class Estimator(object): logging.warning('Input graph does not use tf.data.Dataset or contain a ' 'QueueRunner. That means predict yields forever. ' 'This is probably a mistake.') - return result + return result, input_hooks def _get_features_and_labels_from_input_fn(self, input_fn, mode): + """Extracts the `features` and labels from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) + input_hooks = [] + if isinstance(result, dataset_ops.Dataset): + iterator = result.make_initializable_iterator() + input_hooks.append(_DatasetInitializerHook(iterator)) + result = iterator.get_next() if isinstance(result, (list, tuple)): if len(result) != 2: raise ValueError( 'input_fn should return (feautures, labels) as a len 2 tuple.') - return result - return result, None + return result[0], result[1], input_hooks + return result, None, input_hooks def _extract_batch_length(self, preds_evaluated): """Extracts batch length of predictions.""" @@ -723,8 +735,10 @@ class Estimator(object): random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) training_util._get_or_create_global_step_read() # pylint: disable=protected-access - features, labels = self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.TRAIN) + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.TRAIN)) + worker_hooks.extend(input_hooks) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) # Check if the user created a loss summary, and add one if they didn't. @@ -822,8 +836,9 @@ class Estimator(object): with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) - features, labels = self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.EVAL) + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.EVAL)) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.config) @@ -844,7 +859,8 @@ class Estimator(object): 'already defines a default metric with the same name.') eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor - all_hooks = list(hooks or []) + all_hooks = list(input_hooks) + all_hooks.extend(hooks) all_hooks.extend(list(estimator_spec.evaluation_hooks or [])) eval_results = evaluation._evaluate_once( # pylint: disable=protected-access @@ -1039,3 +1055,16 @@ def _has_dataset_or_queue_runner(maybe_tensor): # Now, check queue. return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) + + +class _DatasetInitializerHook(training.SessionRunHook): + + def __init__(self, iterator): + self._iterator = iterator + + def begin(self): + self._initializer = self._iterator.initializer + + def after_create_session(self, session, coord): + del coord + session.run(self._initializer) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index c1b773b8c4..db64fbc9cc 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -913,6 +913,80 @@ class EstimatorGetVariablesTest(test.TestCase): self.assertEqual(3., est.get_variable_value('three')) +class EstimatorDatasetIntegrationTest(test.TestCase): + """Tests dataset integration.""" + + def test_returned_by_input_fn(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors(([1.], [2.])) + + def _model_fn(features, labels, mode): + return model_fn_lib.EstimatorSpec( + mode, + loss=features + labels, # 1 + 2 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + scores = est.evaluate(_input_fn, steps=1) + self.assertEqual(3., scores[model_fn_lib.LOSS_METRIC_KEY]) + + def test_with_none_labels(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors([7.]) + + def _model_fn(features, labels, mode): + self.assertIsNone(labels) + return model_fn_lib.EstimatorSpec( + mode, + loss=features, # 7 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + scores = est.evaluate(_input_fn, steps=1) + self.assertEqual(7., scores[model_fn_lib.LOSS_METRIC_KEY]) + + def test_with_predict(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors([10.]) + + def _model_fn(features, labels, mode): + _ = labels + return model_fn_lib.EstimatorSpec( + mode, + predictions=features, # 10 + loss=features, # 10 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + self.assertEqual([10.], next(est.predict(input_fn=_input_fn))) + + def test_batching(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensor_slices(([[1.], [2.]], + [[10.], [20.]])).batch(1) + + def _model_fn(features, labels, mode): + return model_fn_lib.EstimatorSpec( + mode, + predictions=features, + loss=features + (0 if labels is None else labels), # 11, 22 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn) + scores = est.evaluate(_input_fn) + # (11 + 22)/2 = 16.5 + self.assertEqual(16.5, scores[model_fn_lib.LOSS_METRIC_KEY]) + self.assertEqual([1., 2.], list(est.predict(_input_fn))) + + class EstimatorEvaluateTest(test.TestCase): def test_input_fn_args(self): diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 29cf223724..366025a0d8 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -692,7 +692,10 @@ class _FuncGraph(ops.Graph): else: # Substitute with a placeholder. self.extra_inputs.append(x) - ph = array_ops.placeholder(x.dtype, shape=x.get_shape()) + # Hoist the new input placeholder out of any control flow context + # we're currently in. + with ops.control_dependencies(None): + ph = array_ops.placeholder(x.dtype, shape=x.get_shape()) # pylint: disable=protected-access ph._handle_data = x._handle_data # pylint: enable=protected-access diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index ba43e9199b..11f343c579 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -724,6 +724,38 @@ class FunctionTest(test.TestCase): # NOTE: We still do not support capturing control deps. _ = Foo(x) + def testCaptureInWhileLoop(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + + @function.Defun() + def Foo(): + return control_flow_ops.while_loop(lambda i: i < 10, + lambda i: i + x, + [0]) + y = Foo() + + with self.test_session(graph=g) as sess: + self.assertEqual(sess.run(y), 10) + + def testCaptureInCond(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + + @function.Defun(dtypes.bool) + def Foo(pred): + return control_flow_ops.cond(pred, + lambda: x, + lambda: x + 1) + y = Foo(True) + z = Foo(False) + + with self.test_session(graph=g) as sess: + self.assertEqual(sess.run(y), 1) + self.assertEqual(sess.run(z), 2) + def testStableName(self): @function.Defun() diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 434cbda7ad..ada8c30fab 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -179,12 +179,11 @@ def _ProcessInputMapParam(input_map): def _ProcessReturnElementsParam(return_elements): """Type-checks and possibly canonicalizes `return_elements`.""" - if return_elements is not None: - return_elements = tuple(return_elements) - if not all(isinstance(x, compat.bytes_or_text_types) - for x in return_elements): - raise TypeError('return_elements must be a list of strings.') - return return_elements + if return_elements is None: return None + if not all(isinstance(x, compat.bytes_or_text_types) + for x in return_elements): + raise TypeError('return_elements must be a list of strings.') + return tuple(compat.as_str(x) for x in return_elements) def _FindAttrInOpDef(attr_name, op_def): @@ -194,24 +193,125 @@ def _FindAttrInOpDef(attr_name, op_def): return None -def _PopulateTFImportGraphDefOptions(options, prefix, return_elements): +def _ConvertInputMapValues(name, input_map): + """Ensures all input map values are tensors. + + This should be called from inside the import name scope. + + Args: + name: the `name` argument passed to import_graph_def + input_map: the `input_map` argument passed to import_graph_def. + + Returns: + An possibly-updated version of `input_map`. + + Raises: + ValueError: if input map values cannot be converted due to empty name scope. + """ + if not all(isinstance(v, ops.Tensor) for v in input_map.values()): + if name == '': # pylint: disable=g-explicit-bool-comparison + raise ValueError( + 'tf.import_graph_def() requires a non-empty `name` if `input_map` ' + 'contains non-Tensor values. Try calling tf.convert_to_tensor() on ' + '`input_map` values before calling tf.import_graph_def().') + with ops.name_scope('_inputs'): + input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()} + return input_map + + +def _PopulateTFImportGraphDefOptions(options, prefix, input_map, + return_elements): """Populates the TF_ImportGraphDefOptions `options`.""" c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix) + for input_src, input_dst in input_map.items(): + input_src = compat.as_str(input_src) + if input_src.startswith('^'): + src_name = compat.as_bytes(input_src[1:]) + dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access + c_api.TF_ImportGraphDefOptionsRemapControlDependency(options, src_name, + dst_op) + else: + src_name, src_idx = _ParseTensorName(input_src) + src_name = compat.as_str(src_name) + dst_output = input_dst._as_tf_output() # pylint: disable=protected-access + c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, + src_idx, dst_output) for name in return_elements or []: if ':' in name: op_name, index = _ParseTensorName(name) + op_name = compat.as_str(op_name) c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index) else: - c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name) + c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, + compat.as_str(name)) + + # TODO(skyewm): control dependencies def _ProcessNewOps(graph): """Processes the newly-added TF_Operations in `graph`.""" - for c_op in c_api_util.new_tf_operations(graph): - graph._create_op_from_tf_operation(c_op) # pylint: disable=protected-access + # Maps from a node to the names of the ops it's colocated with, if colocation + # is specified in the attributes. + colocation_pairs = {} - # TODO(skyewm): colocation logic + for c_op in c_api_util.new_tf_operations(graph): + # pylint: disable=protected-access + new_op = graph._create_op_from_tf_operation(c_op, compute_device=False) + # pylint: enable=protected-access + + colocation_names = _GetColocationNames(new_op) + if colocation_names: + colocation_pairs[new_op] = colocation_names + # Don't apply this op's device function, since colocation constraints + # override device functions. Note that this op's device may still be set + # by the loop below. + else: + with _MaybeDevice(new_op.device): + graph._apply_device_functions(new_op) # pylint: disable=protected-access + + # The following loop populates the device field of ops that are colocated + # with another op. This is implied by the colocation attribute, but we + # propagate the device field for completeness. + for op, coloc_op_list in colocation_pairs.items(): + coloc_device = None + # Find any device in the list of colocated ops that have a device, if it + # exists. We assume that if multiple ops have devices, they refer to the + # same device. Otherwise, a runtime error will occur since the colocation + # property cannot be guaranteed. + # + # One possible improvement is to try to check for compatibility of all + # devices in this list at import time here, which would require + # implementing a compatibility function for device specs in python. + for coloc_op_name in coloc_op_list: + try: + coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name) # pylint: disable=protected-access + except KeyError: + raise ValueError('Specified colocation to an op that ' + 'does not exist during import: %s in %s' % ( + coloc_op_name, op.name)) + if coloc_op.device: + coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) + break + if coloc_device: + op._set_device(coloc_device) # pylint: disable=protected-access + + +def _GetColocationNames(op): + """Returns names of the ops that `op` should be colocated with.""" + colocation_names = [] + try: + class_values = op.get_attr('_class') + except ValueError: + # No _class attr + return + for val in class_values: + val = compat.as_str(val) + if val.startswith('loc:@'): + colocation_node_name = val[len('loc:@'):] + if colocation_node_name != op.name: + colocation_names.append(colocation_node_name) + return colocation_names def _GatherReturnElements(requested_return_elements, graph, results): @@ -312,17 +412,27 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, else: prefix = '' + # Generate any input map tensors inside name scope + input_map = _ConvertInputMapValues(name, input_map) + scoped_options = c_api_util.ScopedTFImportGraphDefOptions() options = scoped_options.options - _PopulateTFImportGraphDefOptions(options, prefix, return_elements) + _PopulateTFImportGraphDefOptions(options, prefix, input_map, + return_elements) with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: - with errors.raise_exception_on_not_ok_status() as status: - results = c_api.TF_GraphImportGraphDefWithResults( - graph._c_graph, serialized, options, status) # pylint: disable=protected-access + try: + with errors.raise_exception_on_not_ok_status() as status: + results = c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, options, status) # pylint: disable=protected-access + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) _ProcessNewOps(graph) + # TODO(skyewm): error if unused input map key + if return_elements is None: return None else: @@ -359,16 +469,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, # more nuanced. g.graph_def_versions.CopyFrom(graph_def.versions) - if not all(isinstance(v, ops.Tensor) for v in input_map.values()): - if not scope: - # The caller must have passed `name=''`. - raise ValueError( - 'tf.import_graph_def() requires a non-empty `name` if `input_map`' - ' contains non-Tensor values. Try calling tf.convert_to_tensor() ' - 'on `input_map` values before calling tf.import_graph_def().') - with ops.name_scope('_inputs'): - input_map = {k: ops.convert_to_tensor(v) - for k, v in input_map.items()} + input_map = _ConvertInputMapValues(name, input_map) # NOTE(mrry): We do this in two passes, because there may be a cycle in # `graph_def`. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 5a6187c8a6..4a215abd2e 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -201,8 +201,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(outer_inner_c.name, "outer/inner/c_1") def testInputMap(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) @@ -230,8 +228,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(d.inputs[1], feed_b_1) def testInputMapBytes(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) @@ -259,8 +255,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(d.inputs[1], feed_b_1) def testInputMapUnicode(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) @@ -299,8 +293,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(b.inputs[0], a.outputs[0]) def testInputMapImplicitZerothOutput(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) b, = importer.import_graph_def( @@ -453,8 +445,6 @@ class ImportGraphDefTest(test.TestCase): self.assertTrue("Input tensor 'A:0' not found" in str(e.exception)) def testMissingInputOpInGraphDefButAppearsInInputMap(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): feed_a_0 = constant_op.constant(5.0) b, = importer.import_graph_def( @@ -589,19 +579,20 @@ class ImportGraphDefTest(test.TestCase): self.assertTrue("not found in graph_def: [A:2]" in str(e.exception)) def testInputMapTypeMismatch(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - + if ops._USE_C_API: + error_msg = ("Input 0 of node import/B was passed float from Const:0 " + "incompatible with expected int32.") + else: + error_msg = ("Cannot convert a tensor of type float32 to an input of " + "type int32.") with ops.Graph().as_default(): - with self.assertRaises(ValueError) as e: + with self.assertRaisesRegexp(ValueError, error_msg): importer.import_graph_def( self._MakeGraphDef(""" node { name: 'A' op: 'IntOutput' } node { name: 'B' op: 'IntInput' input: 'A:0' } """), input_map={"A:0": constant_op.constant(5.0)}) - self.assertTrue( - "Cannot convert a tensor of type float32 to an input of type int32." - in str(e.exception)) def testNoReturns(self): with ops.Graph().as_default() as g: @@ -651,8 +642,6 @@ class ImportGraphDefTest(test.TestCase): b.node_def.attr["_class"]) def testColocationWithDeviceFn(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' attr { key: '_class' @@ -674,23 +663,17 @@ class ImportGraphDefTest(test.TestCase): with ops.Graph().as_default(): with ops.device(CustomDeviceFn): - b, = importer.import_graph_def( - original_graph_def, return_elements=["B"], name="imported_graph") - - self.assertProtoEqualsVersion(""" - node { name: 'imported_graph/A' op: 'None' device: "/device:A:0" - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } - } - node { name: 'imported_graph/B' op: 'None' device: "/device:A:0" - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } }""", b.graph.as_graph_def()) - - # Test a scenario where 'A' doesn't get a device; 'A' should - # not have a device, but during runtime will get colocated with - # 'B' because of the colocation attribute. + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="imported_graph") + self.assertEqual(a.device, "/device:A:0") + self.assertEqual(b.device, "/device:A:0") + self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) + + # Test a scenario where 'A' doesn't get a device; 'A' should not have a + # device, but during runtime will get colocated with 'B' because of the + # colocation attribute. B's device function is still overridden by A. def BDeviceFn(op): if "B" in op.name: return "/device:B:0" @@ -698,19 +681,13 @@ class ImportGraphDefTest(test.TestCase): with ops.Graph().as_default(): with ops.device(BDeviceFn): - b, = importer.import_graph_def( - original_graph_def, return_elements=["B"], name="imported_graph") - - self.assertProtoEqualsVersion(""" - node { name: 'imported_graph/A' op: 'None' - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } - } - node { name: 'imported_graph/B' op: 'None' - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } }""", b.graph.as_graph_def()) + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="imported_graph") + self.assertEqual(a.device, "") + self.assertEqual(b.device, "") + self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) # Only A gets a device, so B inherits it implicitly. def ADeviceFn(op): @@ -720,23 +697,15 @@ class ImportGraphDefTest(test.TestCase): with ops.Graph().as_default(): with ops.device(ADeviceFn): - b, = importer.import_graph_def( - original_graph_def, return_elements=["B"], name="imported_graph") - - self.assertProtoEqualsVersion(""" - node { name: 'imported_graph/A' op: 'None' device: "/device:A:0" - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } - } - node { name: 'imported_graph/B' op: 'None' device: "/device:A:0" - attr { - key: '_class' value { list { s: 'loc:@imported_graph/A' } } - } }""", b.graph.as_graph_def()) + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="imported_graph") + self.assertEqual(a.device, "/device:A:0") + self.assertEqual(b.device, "/device:A:0") + self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) def testMultipleColocationWithDeviceFn(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None'} node { name: 'B' op: 'None'} @@ -757,23 +726,19 @@ class ImportGraphDefTest(test.TestCase): with ops.Graph().as_default(): with ops.device(CustomDeviceFn): - c, = importer.import_graph_def( - original_graph_def, return_elements=["C"], name="imported_graph") - - self.assertProtoEqualsVersion(""" - node { name: 'imported_graph/A' op: 'None' } - node { name: 'imported_graph/B' op: 'None' device: "/device:B:0" } - node { name: 'imported_graph/C' op: 'None' device: "/device:B:0" - attr { - key: '_class' value { - list { s: 'loc:@imported_graph/A' - s: 'loc:@imported_graph/B' } - } - } - }""", c.graph.as_graph_def()) + a, b, c = importer.import_graph_def(original_graph_def, + return_elements=["A", "B", "C"], + name="imported_graph") + self.assertEqual(a.device, "") + self.assertEqual(b.device, "/device:B:0") + self.assertEqual(c.device, "/device:B:0") + self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/B"]) + self.assertEqual(c.colocation_groups(), + [b"loc:@imported_graph/A", b"loc:@imported_graph/B"]) def testNamePrefixColocationAttrsMultipleImport(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API + if ops._USE_C_API: return # TODO(skyewm): set uniquify_names original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' } @@ -800,15 +765,19 @@ class ImportGraphDefTest(test.TestCase): } }""", b.graph.as_graph_def()) def testNamePrefixColocationAttrsNotFound(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - original_graph_def = self._MakeGraphDef(""" node { name: 'B' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } }""") + + if ops._USE_C_API: + error_msg = "Node 'B' expects to be colocated with unknown node 'A'" + else: + error_msg = "does not exist during import" + with ops.Graph().as_default(): - with self.assertRaisesRegexp(ValueError, "does not exist during import"): + with self.assertRaisesRegexp(ValueError, error_msg): importer.import_graph_def( original_graph_def, return_elements=["B"], name="imported_graph") @@ -825,8 +794,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual("graph_def must be a GraphDef proto.", str(e.exception)) def testInvalidInputForInputMap(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): with self.assertRaises(TypeError) as e: importer.import_graph_def( @@ -967,7 +934,7 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(2, len(ops_with_two_inputs)) def testGradient(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API + if ops._USE_C_API: return # TODO(skyewm): get_shape() doesn't work with ops.Graph().as_default() as g: inputs = array_ops.placeholder( @@ -1226,8 +1193,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(z_val, -2.0) def testImportGraphWithFunctionTwice(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - g = ops.Graph() with g.as_default(): @function.Defun() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 60df8f82f0..13e6426447 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -35,6 +35,7 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import versions_pb2 +from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.eager import context from tensorflow.python.eager import core @@ -373,6 +374,19 @@ class Tensor(_TensorLike): A `TensorShape` representing the shape of this tensor. """ + if _USE_C_API: + graph = self._op._graph._c_graph # pylint: disable=protected-access + with errors.raise_exception_on_not_ok_status() as status: + num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), + status) + if num_dims == -1: + dim_list = None + else: + with errors.raise_exception_on_not_ok_status() as status: + dim_list = c_api.TF_GraphGetTensorShape_wrapper( + graph, self._as_tf_output(), num_dims, status) + dim_list = [None if i == -1 else i for i in dim_list] + return tensor_shape.TensorShape(dim_list) return self._shape def __iter__(self): @@ -392,8 +406,8 @@ class Tensor(_TensorLike): yield self[i] def _shape_as_list(self): - if self._shape.ndims is not None: - return [dim.value for dim in self._shape.dims] + if self.shape.ndims is not None: + return [dim.value for dim in self.shape.dims] else: return None @@ -409,7 +423,7 @@ class Tensor(_TensorLike): Returns: Integer rank or None """ - return self._shape.ndims + return self.shape.ndims def get_shape(self): """Alias of Tensor.shape.""" @@ -440,14 +454,35 @@ class Tensor(_TensorLike): ``` Args: - shape: A `TensorShape` representing the shape of this tensor. + shape: A `TensorShape` representing the shape of this tensor, a + `TensorShapeProto`, a list, a tuple, or None. Raises: ValueError: If `shape` is not compatible with the current shape of this tensor. """ - # TODO(skyewm): call C API - self._shape = self._shape.merge_with(shape) + if not _USE_C_API: + self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access + return + if not isinstance(shape, tensor_shape.TensorShape): + shape = tensor_shape.TensorShape(shape) + dim_list = [] + if shape.dims is None: + unknown_shape = True + else: + unknown_shape = False + for dim in shape.dims: + if dim.value is None: + dim_list.append(-1) + else: + dim_list.append(dim.value) + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_GraphSetTensorShape_wrapper( + self._op._graph._c_graph, # pylint: disable=protected-access + self._as_tf_output(), + dim_list, + unknown_shape, + status) @property def value_index(self): @@ -598,11 +633,6 @@ class Tensor(_TensorLike): """ return _eval_using_default_session(self, feed_dict, self.graph, session) - def _dup(self): - ret = copy.copy(self) - ret._id = uid() # pylint: disable=protected-access - return ret - # TODO(agarwal): consider getting rid of this. class _EagerTensorBase(Tensor): @@ -728,9 +758,6 @@ class _EagerTensorBase(Tensor): return new_tensor # pylint: enable=protected-access - def _dup(self): - return self._copy(device_name=self.device) - @property def shape(self): return tensor_shape.TensorShape(self._shape_tuple()) @@ -1634,8 +1661,6 @@ class Operation(object): self._id_value = self._graph._next_id() # pylint: disable=protected-access self._recompute_node_def() - self._graph._add_op(self) # pylint: disable=protected-access - def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. @@ -1795,7 +1820,7 @@ class Operation(object): c_api.SetRequestedDevice( self._graph._c_graph, # pylint: disable=protected-access self._c_op, # pylint: disable=protected-access - _device_string(device)) + compat.as_str(_device_string(device))) else: self._node_def.device = _device_string(device) @@ -2084,7 +2109,7 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if _USE_C_API: + if self._c_op: buf = c_api.TF_NewBufferFromString( compat.as_bytes(attr_value.SerializeToString())) try: @@ -2653,11 +2678,16 @@ class Graph(object): # TODO(skyewm): fold as much of the above as possible into the C # implementation - if _USE_C_API: + if _USE_C_API or self._use_c_api_hack(): self._scoped_c_graph = c_api_util.ScopedTFGraph() else: self._scoped_c_graph = None + # TODO(apassos) remove once the C API is used by default. + def _use_c_api_hack(self): + """Temporary hack; can be overridden to force C API usage.""" + return False + def _convert_stack(self, stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. @@ -2986,9 +3016,14 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access if self._c_graph: - assert function._c_func, ( - "Cannot add function created without C API support to graph " - "created with C API support") + # Handle functions created without using the C API. TODO(apassos,skyewm) + # remove this when all functions are generated using the C API by default + # as this will be unnecessary. + if not function._c_func: + with errors.raise_exception_on_not_ok_status() as status: + serialized = function.definition.SerializeToString() + function._c_func = c_api.TF_FunctionImportFunctionDef( + serialized, status) with errors.raise_exception_on_not_ok_status() as status: gradient = function._grad_func._c_func if function._grad_func else None c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient, @@ -3099,12 +3134,11 @@ class Graph(object): input_types=input_types, original_op=self._default_original_op, op_def=op_def) - self._create_op_helper(ret, compute_shapes=compute_shapes, compute_device=compute_device) return ret - def _create_op_from_tf_operation(self, c_op): + def _create_op_from_tf_operation(self, c_op, compute_device=True): """Creates an `Operation` in this graph from the supplied TF_Operation. This method is like create_op() except the new Operation is constructed @@ -3114,6 +3148,8 @@ class Graph(object): Args: c_op: a wrapped TF_Operation + compute_device: (Optional.) If True, device functions will be executed + to compute the device property of the Operation. Returns: An `Operation` object. @@ -3124,7 +3160,7 @@ class Graph(object): for output in tf_outputs) control_inputs = self._control_dependencies_for_inputs(input_ops) ret = Operation(c_op, self, control_inputs=control_inputs) - self._create_op_helper(ret) + self._create_op_helper(ret, compute_device=compute_device) return ret def _create_op_helper(self, op, compute_shapes=True, compute_device=True): @@ -3138,6 +3174,8 @@ class Graph(object): # compute_shapes argument. if op._c_op or compute_shapes: # pylint: disable=protected-access set_shapes_for_outputs(op) + # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. + self._add_op(op) # Apply any additional attributes requested. Do not overwrite any existing # attributes. @@ -4517,15 +4555,11 @@ def control_dependencies(control_inputs): See @{tf.Graph.control_dependencies} for more details. - When eager execution is enabled, any callable object in the `control_inputs` - list will be called. - Args: control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations defined in the context. Can also be `None` to clear the control - dependencies. If eager execution is enabled, any callable object in the - `control_inputs` list will be called. + dependencies. Returns: A context manager that specifies control dependencies for all @@ -4534,11 +4568,6 @@ def control_dependencies(control_inputs): if context.in_graph_mode(): return get_default_graph().control_dependencies(control_inputs) else: - if control_inputs: - # Excute any pending callables. - for control in control_inputs: - if callable(control): - control() return _NullContextmanager() @@ -4794,6 +4823,16 @@ def enable_eager_execution(config=None, device_policy=None): or if trying to create a context with nontrivial options which differ from those of the existing context. """ + if config is not None and not isinstance(config, config_pb2.ConfigProto): + raise TypeError( + "config must be a tf.ConfigProto, but got %s" % type(config)) + if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, + context.DEVICE_PLACEMENT_WARN, + context.DEVICE_PLACEMENT_SILENT): + raise ValueError( + "device_policy must be one of None, tfe.DEVICE_PLACEMENT_EXPLICIT, " + "tfe.DEVICE_PLACEMENT_WARN, tfe.DEVICE_PLACEMENT_SILENT" + ) # pylint: disable=protected-access if context._default_mode == context.GRAPH_MODE: graph_mode_has_been_used = ( diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index cd296ccdc5..b1ad6ad744 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -274,6 +274,7 @@ class OperationTest(test_util.TensorFlowTestCase): op1 = ops.Operation( ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]) + g._add_op(op1) self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) self.assertEquals([], list(op1.inputs)) ref_t, nonref_t = op1.values() @@ -282,12 +283,14 @@ class OperationTest(test_util.TensorFlowTestCase): ops._NodeDef("RefInputFloatInput", "op2"), g, [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32]) + g._add_op(op2) self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) self.assertEquals([ref_t, nonref_t], list(op2.inputs)) op3 = ops.Operation( ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) + g._add_op(op3) self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) @@ -1537,7 +1540,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): self.assertEqual(future.calls, 1) else: a = constant_op.constant(1.0) - b = future + b = future() with ops.control_dependencies([a, b]): c = constant_op.constant(3.0) self.assertEqual(future.calls, 1) @@ -1876,6 +1879,24 @@ class GraphTest(test_util.TensorFlowTestCase): gc.collect() self.assertIsNone(g_ref()) + def testRunnableAfterInvalidShape(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + math_ops.add([1, 2], [1, 2, 3]) + a = constant_op.constant(1) + with session.Session() as sess: + sess.run(a) + + def testRunnableAfterInvalidShapeWithKernelLabelMap(self): + g = ops.Graph() + with g.as_default(): + with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): + with self.assertRaises(ValueError): + test_ops.kernel_label_required(1) + a = constant_op.constant(1) + with session.Session() as sess: + sess.run(a) + @test_util.with_c_api class AttrScopeTest(test_util.TensorFlowTestCase): @@ -2395,6 +2416,13 @@ class InputTypesTest(test_util.TensorFlowTestCase): self.assertEqual([dtypes.double, dtypes.double], z.op._input_dtypes) # pylint: enable=protected-access + def testBadArgumentsToEnableEagerExecution(self): + with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): + ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) + with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): + c = config_pb2.ConfigProto() + ops.enable_eager_execution(c, c) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index 25bb7af20c..dbabce0962 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -26,6 +26,16 @@ REGISTER_OP("KernelLabel") .Output("result: string") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("KernelLabelRequired") + .Input("input: int32") + .Output("result: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out)); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("GraphDefVersion") .Output("version: int32") .SetIsStateful() @@ -104,6 +114,14 @@ REGISTER_KERNEL_BUILDER(Name("KernelLabel") .Label("overload_2"), KernelLabelOp<OVERLOAD_2_LABEL>); +// All "KernelLabelRequired" kernels have labels +REGISTER_KERNEL_BUILDER( + Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"), + KernelLabelOp<OVERLOAD_1_LABEL>); +REGISTER_KERNEL_BUILDER( + Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"), + KernelLabelOp<OVERLOAD_2_LABEL>); + class GraphDefVersionOp : public OpKernel { public: explicit GraphDefVersionOp(OpKernelConstruction* ctx) diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index 7dd79f7c82..8f72a425c3 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -120,7 +120,7 @@ static PyObject* TF_GetOpProperties(GItem item) { Py_RETURN_NONE; } tensorflow::grappler::GraphProperties properties(*item); - tensorflow::Status status = properties.InferStatically(); + tensorflow::Status status = properties.InferStatically(false); if (!status.ok()) { Py_RETURN_NONE; } diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 626e0502cb..50735fb567 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -190,7 +190,7 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(expected_num_transposes, num_transposes) self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Reshape-0', nodes) - self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1', + self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 7d365c3be9..da5b03234e 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -27,7 +27,7 @@ ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} Status ModelAnalyzer::GenerateReport(std::ostream& os) { GraphProperties properties(item_); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(false)); for (const auto& node : item_.MainOpsFanin()) { PrintNodeInfo(node, properties, os); diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py index 9c17fbb4a7..79dfcd1bb6 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks_test.py +++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py @@ -685,8 +685,8 @@ class KerasCallbacksTest(test.TestCase): # fit w/o validation data should raise ValueError if histogram_freq > 0 cbs = callbacks_factory(histogram_freq=1) with self.assertRaises(ValueError): - model.fit(x_train, y_train, batch_size=BATCH_SIZE, - callbacks=cbs, epochs=3) + model.fit( + x_train, y_train, batch_size=BATCH_SIZE, callbacks=cbs, epochs=3) for cb in cbs: cb.on_train_end() @@ -695,8 +695,8 @@ class KerasCallbacksTest(test.TestCase): # histogram_freq > 0 cbs = callbacks_factory(histogram_freq=1) with self.assertRaises(ValueError): - model.fit_generator(data_generator(True), len(x_train), epochs=2, - callbacks=cbs) + model.fit_generator( + data_generator(True), len(x_train), epochs=2, callbacks=cbs) for cb in cbs: cb.on_train_end() @@ -705,10 +705,13 @@ class KerasCallbacksTest(test.TestCase): # histogram_freq > 0 cbs = callbacks_factory(histogram_freq=1) with self.assertRaises(ValueError): - model.fit_generator(data_generator(True), len(x_train), epochs=2, - validation_data=data_generator(False), - validation_steps=1, - callbacks=cbs) + model.fit_generator( + data_generator(True), + len(x_train), + epochs=2, + validation_data=data_generator(False), + validation_steps=1, + callbacks=cbs) for cb in cbs: cb.on_train_end() diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py index 2003e19a0a..a8fc18c17a 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py @@ -78,7 +78,7 @@ class HDF5Matrix(object): def __len__(self): return self.end - self.start - def __del__(self): + def __del__(self): self._f.close() def __getitem__(self, key): diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 2ec162578c..f6721de32a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -676,6 +676,7 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:state_ops", "//tensorflow/python:variables", + "//tensorflow/python:resource_variable_ops", ], tags = ["noasan"], # http://b/32635055 ) diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py index 3b71586b55..8e9d75667d 100644 --- a/tensorflow/python/kernel_tests/constant_op_eager_test.py +++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py @@ -237,6 +237,39 @@ class ConstantTest(test.TestCase): self._testAll((1, x)) self._testAll((x, 1)) + def testInvalidLength(self): + + class BadList(list): + + def __init__(self): + super(BadList, self).__init__([1, 2, 3]) # pylint: disable=invalid-length-returned + + def __len__(self): + return -1 + + with self.assertRaisesRegexp(ValueError, "should return >= 0"): + constant_op.constant([BadList()]) + with self.assertRaisesRegexp(ValueError, "mixed types"): + constant_op.constant([1, 2, BadList()]) + with self.assertRaisesRegexp(ValueError, "should return >= 0"): + constant_op.constant(BadList()) + with self.assertRaisesRegexp(ValueError, "should return >= 0"): + constant_op.constant([[BadList(), 2], 3]) + with self.assertRaisesRegexp(ValueError, "should return >= 0"): + constant_op.constant([BadList(), [1, 2, 3]]) + with self.assertRaisesRegexp(ValueError, "should return >= 0"): + constant_op.constant([BadList(), []]) + + # TODO(allenl, josh11b): These cases should return exceptions rather than + # working (currently shape checking only checks the first element of each + # sequence recursively). Maybe the first one is fine, but the second one + # silently truncating is rather bad. + + # with self.assertRaisesRegexp(ValueError, "should return >= 0"): + # constant_op.constant([[3, 2, 1], BadList()]) + # with self.assertRaisesRegexp(ValueError, "should return >= 0"): + # constant_op.constant([[], BadList()]) + def testSparseValuesRaiseErrors(self): with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"): constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32) diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py index 1679857bd5..be299beee4 100644 --- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py +++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py @@ -42,17 +42,21 @@ class Conv2DBackpropFilterGradTest(test.TestCase): filter_shape = [3, 3, 4, 6] # Make a convolution op with the current settings, just to easily get # the shape of the output. - conv_out = nn_ops.conv2d(in_val, - array_ops.zeros(filter_shape), - [1, stride, stride, 1], padding) + conv_out = nn_ops.conv2d( + in_val, + array_ops.zeros(filter_shape), + strides=[1, stride, stride, 1], + padding=padding) out_backprop_shape = conv_out.get_shape().as_list() out_backprop_val = constant_op.constant( 2 * np.random.random_sample(out_backprop_shape) - 1, dtype=dtypes.float32) - output = nn_ops.conv2d_backprop_filter(in_val, filter_shape, - out_backprop_val, - [1, stride, stride, 1], - padding) + output = nn_ops.conv2d_backprop_filter( + in_val, + filter_shape, + out_backprop_val, + strides=[1, stride, stride, 1], + padding=padding) err = gradient_checker.compute_gradient_error( [in_val, out_backprop_val], [in_shape, out_backprop_shape], output, filter_shape) @@ -60,6 +64,42 @@ class Conv2DBackpropFilterGradTest(test.TestCase): err_tolerance = 2e-3 self.assertLess(err, err_tolerance) + def testGradientDilatedConv(self): + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + for padding in ["SAME", "VALID"]: + for stride in [1, 2]: + np.random.seed(1) + in_shape = [5, 8, 6, 4] + in_val = constant_op.constant( + 2 * np.random.random_sample(in_shape) - 1, dtype=dtypes.float32) + filter_shape = [3, 3, 4, 6] + # Make a convolution op with the current settings, + # just to easily get the shape of the output. + conv_out = nn_ops.conv2d( + in_val, + array_ops.zeros(filter_shape), + dilations=[1, 2, 2, 1], + strides=[1, stride, stride, 1], + padding=padding) + out_backprop_shape = conv_out.get_shape().as_list() + out_backprop_val = constant_op.constant( + 2 * np.random.random_sample(out_backprop_shape) - 1, + dtype=dtypes.float32) + output = nn_ops.conv2d_backprop_filter( + in_val, + filter_shape, + out_backprop_val, + dilations=[1, 2, 2, 1], + strides=[1, stride, stride, 1], + padding=padding) + err = gradient_checker.compute_gradient_error( + [in_val, out_backprop_val], [in_shape, out_backprop_shape], + output, filter_shape) + print("conv2d_backprop_filter gradient err = %g " % err) + err_tolerance = 2e-3 + self.assertLess(err, err_tolerance) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 22e5400c37..bf7245a2ae 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os import time @@ -32,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops @@ -240,6 +242,77 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5) + def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes, + stride, dilation, padding, data_format, + use_gpu): + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] + x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] + with test_util.device(use_gpu): + t1 = constant_op.constant(x1, shape=tensor_in_sizes) + t2 = constant_op.constant(x2, shape=filter_in_sizes) + if isinstance(stride, collections.Iterable): + strides = list(stride) + else: + strides = [stride, stride] + if data_format == "NCHW": + t1 = test_util.NHWCToNCHW(t1) + full_strides = [1, 1] + strides + full_dilation = [1, 1] + dilation + else: + full_strides = [1] + strides + [1] + full_dilation = [1] + dilation + [1] + expected = nn_ops.convolution( + t1, + t2, + padding=padding, + strides=strides, + dilation_rate=dilation, + data_format=data_format) + computed = nn_ops.conv2d( + t1, + t2, + strides=full_strides, + dilations=full_dilation, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + expected = test_util.NCHWToNHWC(expected) + computed = test_util.NCHWToNHWC(computed) + return expected, computed + + def _VerifyDilatedConvValues(self, tensor_in_sizes, filter_in_sizes, strides, + padding, dilations): + expected_results = [] + computed_results = [] + default_dilations = (dilations[0] == 1 and dilations[1] == 1) + for data_format, use_gpu in GetTestConfigs(): + # If any dilation rate is larger than 1, only do test on the GPU + # because we currently do not have a CPU implementation for arbitrary + # dilation rates. + if default_dilations or use_gpu: + expected, computed = self._ComputeReferenceDilatedConv( + tensor_in_sizes, filter_in_sizes, strides, dilations, padding, + data_format, use_gpu) + expected_results.append(expected) + computed_results.append(computed) + tolerance = 1e-2 if use_gpu else 1e-5 + expected_values = self.evaluate(expected_results) + computed_values = self.evaluate(computed_results) + for e_value, c_value in zip(expected_values, computed_values): + print("expected = ", e_value) + print("actual = ", c_value) + self.assertAllClose( + e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6) + def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding, expected): tensors = [] @@ -280,6 +353,16 @@ class Conv2DTest(test.TestCase): expected=expected_output) @test_util.run_in_graph_and_eager_modes() + def testConv2D2x2Filter2x1Dilation(self): + if test.is_gpu_available(cuda_only=True): + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 4, 4, 1], + filter_in_sizes=[2, 2, 1, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") + + @test_util.run_in_graph_and_eager_modes() def testConv2DEmpty(self): expected_output = [] self._VerifyValues( @@ -290,6 +373,16 @@ class Conv2DTest(test.TestCase): expected=expected_output) @test_util.run_in_graph_and_eager_modes() + def testConv2DEmptyDilation(self): + if test.is_gpu_available(cuda_only=True): + self._VerifyDilatedConvValues( + tensor_in_sizes=[0, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") + + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -301,6 +394,16 @@ class Conv2DTest(test.TestCase): expected=expected_output) @test_util.run_in_graph_and_eager_modes() + def testConv2D2x2FilterDilation(self): + if test.is_gpu_available(cuda_only=True): + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + strides=[1, 1], + dilations=[1, 2], + padding="VALID") + + @test_util.run_in_graph_and_eager_modes() def testConv2D1x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [ @@ -315,6 +418,16 @@ class Conv2DTest(test.TestCase): expected=expected_output) @test_util.run_in_graph_and_eager_modes() + def testConv2D1x2FilterDilation(self): + if test.is_gpu_available(cuda_only=True): + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[1, 2, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") + + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride2(self): expected_output = [2271.0, 2367.0, 2463.0] self._VerifyValues( @@ -386,13 +499,23 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=[50, 60]) - # TODO this currently fails. - # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], - # filter_in_sizes=[2, 2, 1, 1], - # strides=[4, 4], padding="SAME", - # expected=[72, 112, 392, 432]) + @test_util.run_in_graph_and_eager_modes() + def testConv2DKernelSizeMatchesInputSizeDilation(self): + if test.is_gpu_available(cuda_only=True): + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 3, 3, 1], + filter_in_sizes=[2, 2, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID") + + # TODO this currently fails. + # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], + # filter_in_sizes=[2, 2, 1, 1], + # strides=[4, 4], padding="SAME", + # expected=[72, 112, 392, 432]) - # Testing for backprops + # Testing for backprops def _RunAndVerifyBackpropInput(self, input_sizes, filter_sizes, output_sizes, strides, padding, expected, data_format, use_gpu, err): @@ -724,6 +847,255 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) + # Testing for backprops + def _RunAndVerifyBackpropInputDilation(self, input_sizes, filter_sizes, + output_sizes, strides, dilations, + padding, data_format, use_gpu, err): + total_input_size = 1 + total_filter_size = 1 + for s in input_sizes: + total_input_size *= s + for s in filter_sizes: + total_filter_size *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_input_size + 1)] + x2 = [f * 1.0 for f in range(1, total_filter_size + 1)] + default_dilations = (dilations[0] == 1 and dilations[1] == 1) + if default_dilations or use_gpu: + with self.test_session(use_gpu=use_gpu) as sess: + if data_format == "NCHW": + input_sizes = test_util.NHWCToNCHW(input_sizes) + t1 = constant_op.constant(x1, shape=input_sizes) + t2 = constant_op.constant(x2, shape=filter_sizes) + full_strides = [1] + strides + [1] + full_dilations = [1] + dilations + [1] + if data_format == "NCHW": + full_strides = test_util.NHWCToNCHW(full_strides) + full_dilations = test_util.NHWCToNCHW(full_dilations) + conv_forward = nn_ops.conv2d( + t1, + t2, + strides=full_strides, + dilations=full_dilations, + padding=padding, + data_format=data_format) + conv_forward_2 = nn_ops.convolution( + t1, + t2, + padding=padding, + strides=strides, + dilation_rate=dilations, + data_format=data_format) + if data_format == "NCHW": + conv_forward = test_util.NCHWToNHWC(conv_forward) + conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2) + conv = gradients_impl.gradients(conv_forward, t1)[0] + conv_2 = gradients_impl.gradients(conv_forward_2, t1)[0] + # "values" consists of two tensors for two backprops + value = sess.run(conv) + value_2 = sess.run(conv_2) + self.assertShapeEqual(value, conv) + self.assertShapeEqual(value_2, conv_2) + print("expected = ", value_2) + print("actual = ", value) + self.assertArrayNear(value_2.flatten(), value.flatten(), err) + + # Testing for backprops + def _RunAndVerifyBackpropFilterDilation(self, input_sizes, filter_sizes, + output_sizes, strides, dilations, + padding, data_format, use_gpu, err): + total_input_size = 1 + total_filter_size = 1 + for s in input_sizes: + total_input_size *= s + for s in filter_sizes: + total_filter_size *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_input_size + 1)] + x2 = [f * 1.0 for f in range(1, total_filter_size + 1)] + default_dilations = (dilations[0] == 1 and dilations[1] == 1) + if default_dilations or use_gpu: + with self.test_session(use_gpu=use_gpu) as sess: + if data_format == "NCHW": + input_sizes = test_util.NHWCToNCHW(input_sizes) + t1 = constant_op.constant(x1, shape=input_sizes) + t2 = constant_op.constant(x2, shape=filter_sizes) + full_strides = [1] + strides + [1] + full_dilations = [1] + dilations + [1] + if data_format == "NCHW": + full_strides = test_util.NHWCToNCHW(full_strides) + full_dilations = test_util.NHWCToNCHW(full_dilations) + conv_forward = nn_ops.conv2d( + t1, + t2, + strides=full_strides, + dilations=full_dilations, + padding=padding, + data_format=data_format) + conv_forward_2 = nn_ops.convolution( + t1, + t2, + padding=padding, + strides=strides, + dilation_rate=dilations, + data_format=data_format) + if data_format == "NCHW": + conv_forward = test_util.NCHWToNHWC(conv_forward) + conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2) + conv = gradients_impl.gradients(conv_forward, t2)[0] + conv_2 = gradients_impl.gradients(conv_forward, t2)[0] + value = sess.run(conv) + value_2 = sess.run(conv_2) + self.assertShapeEqual(value, conv) + self.assertShapeEqual(value_2, conv_2) + print("expected = ", value_2) + print("actual = ", value) + self.assertArrayNear(value_2.flatten(), value.flatten(), err) + + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilterDilation( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilterDilation( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 2, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2DEmptyBackpropFilterDilation1x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilterDilation( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 0], + output_sizes=[1, 1, 2, 0], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilterDilation( + input_sizes=[1, 3, 4, 3], + filter_sizes=[2, 2, 3, 3], + output_sizes=[1, 1, 2, 3], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilterDilation( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + output_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInputDilation( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInputDilation( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 2, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2DEmptyBackpropInputDilation1x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInputDilation( + input_sizes=[0, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[0, 1, 2, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + # The GPU version of this test is not very stable. So adjusting the + # error threshold to 1e-4. + self._RunAndVerifyBackpropInputDilation( + input_sizes=[1, 3, 2, 3], + filter_sizes=[2, 2, 3, 3], + output_sizes=[1, 1, 2, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-4) + + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + if test.is_gpu_available(cuda_only=True): + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInputDilation( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + output_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + # Gradient checkers def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows, filter_cols, in_depth, out_depth, stride_rows, @@ -1457,6 +1829,22 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, return Test +def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding): + + def Test(self): + if test.is_gpu_available(cuda_only=True) and stride == 1: + tf_logging.info("Testing InceptionFwd with dilations %s", + (input_size, filter_size, stride, padding)) + self._VerifyDilatedConvValues( + tensor_in_sizes=input_size, + filter_in_sizes=filter_size, + strides=[stride, stride], + dilations=[2, 2], + padding=padding) + + return Test + + def GetInceptionBackInputTest(input_size, filter_size, output_size, stride, padding, gpu_only=False): @@ -1497,6 +1885,10 @@ if __name__ == "__main__": test_util.run_in_graph_and_eager_modes()( GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))) + setattr( + Conv2DTest, "testInceptionFwdDilatedConv_" + str(index), + test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest( + input_size_, filter_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackInput_" + str(index), test_util.run_in_graph_and_eager_modes()( GetInceptionBackInputTest(input_size_, filter_size_, @@ -1519,6 +1911,9 @@ if __name__ == "__main__": setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused", test_util.run_in_graph_and_eager_modes()( GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))) + setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused", + test_util.run_in_graph_and_eager_modes()( + GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME"))) setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused", test_util.run_in_graph_and_eager_modes()( GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py index c086f46170..c67c26b7be 100644 --- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -68,28 +68,68 @@ class DecodeBmpOpTest(test.TestCase): def testGrayscale(self): img_bytes = [[[255], [0]], [[255], [0]]] encoded_bytes = [ - 0x42, 0x40, - 0x3d, 0, 0, 0, - 0, 0, - 0, 0, - 0x36, 0, 0, 0, - 0x28, 0, 0, 0, - 0x2, 0, 0, 0, - 0x2, 0, 0, 0, - 0x1, 0, - 0x8, 0, - 0, 0, 0, 0, - 0x10, 0, 0, 0, - 0x13, 0xb, 0, 0, - 0x13, 0xb, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, + 0x42, + 0x40, + 0x3d, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0x36, + 0, + 0, + 0, + 0x28, + 0, + 0, + 0, + 0x2, + 0, + 0, + 0, + 0x2, + 0, + 0, + 0, + 0x1, + 0, + 0x8, + 0, + 0, + 0, + 0, + 0, + 0x10, + 0, + 0, + 0, + 0x13, + 0xb, + 0, + 0, + 0x13, + 0xb, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0xff, 0, - 0, 0, + 0, + 0, 0xff, 0, - 0, 0, + 0, + 0, ] byte_string = bytes(bytearray(encoded_bytes)) @@ -100,54 +140,6 @@ class DecodeBmpOpTest(test.TestCase): decoded = decode.eval() self.assertAllEqual(decoded, img_bytes) - def testIncompleteHeader(self): - # Encoded BMP bytes from Wikipedia - encoded_bytes = [ - 0x42, 0x40, - 0x46, 0, 0, 0, - ] - - byte_string = bytes(bytearray(encoded_bytes)) - img_in = constant_op.constant(byte_string, dtype=dtypes.string) - decode = array_ops.squeeze(image_ops.decode_bmp(img_in)) - - with self.test_session(): - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "requires at least 32 bytes to find the header"): - decoded = decode.eval() - - def testIncompleteBody(self): - # Encoded BMP bytes from Wikipedia - encoded_bytes = [ - 0x42, 0x40, - 0x46, 0, 0, 0, - 0, 0, - 0, 0, - 0x36, 0, 0, 0, - 0x28, 0, 0, 0, - 0x2, 0, 0, 0, - 0x2, 0, 0, 0, - 0x1, 0, - 0x18, 0, - 0, 0, 0, 0, - 0x10, 0, 0, 0, - 0x13, 0xb, 0, 0, - 0x13, 0xb, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0xff, - 0xff, 0xff, 0xff, - 0, 0, - ] - - byte_string = bytes(bytearray(encoded_bytes)) - img_in = constant_op.constant(byte_string, dtype=dtypes.string) - decode = array_ops.squeeze(image_ops.decode_bmp(img_in)) - - with self.test_session(): - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "requires at least 68 bytes, got 62 bytes"): - decoded = decode.eval() if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py index edea9c9027..646324cb95 100644 --- a/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py @@ -25,10 +25,11 @@ from tensorflow.python.platform import test class PrefetchDatasetTest(test.TestCase): + def testBufferSize(self): buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -42,7 +43,7 @@ class PrefetchDatasetTest(test.TestCase): def testInvalidBufferSize(self): buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size).make_initializable_iterator() init_op = iterator.initializer with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py index ca48ba6cad..a9dc7b7de0 100644 --- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py +++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py @@ -57,12 +57,14 @@ class MultinomialTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testSmallEntropy(self): random_seed.set_random_seed(1618) - with test_util.device(use_gpu=True): - # A logit value of -10 corresponds to a probability of ~5e-5. - logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]]) - num_samples = 1000 - samples = self.evaluate(random_ops.multinomial(logits, num_samples)) - self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples) + for output_dtype in [np.int32, np.int64]: + with test_util.device(use_gpu=True): + # A logit value of -10 corresponds to a probability of ~5e-5. + logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]]) + num_samples = 1000 + samples = self.evaluate(random_ops.multinomial( + logits, num_samples, output_dtype=output_dtype)) + self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples) def testOneOpMultipleStepsIndependent(self): with self.test_session(use_gpu=True) as sess: diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index a79d66e988..d7bde04230 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -157,6 +158,20 @@ class StatefulScatterNdTest(test.TestCase): result = sess.run(scatter) self.assertAllClose(result, expected) + def testSimpleResource(self): + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + ref = resource_variable_ops.ResourceVariable( + [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32) + expected = np.array([0, 11, 0, 10, 9, 0, 0, 12]) + scatter = state_ops.scatter_nd_update(ref, indices, updates) + init = variables.global_variables_initializer() + + with self.test_session(use_gpu=True) as sess: + sess.run(init) + sess.run(scatter) + self.assertAllClose(ref.eval(), expected) + def testSimple2(self): indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32) updates = constant_op.constant([11., 12.], dtype=dtypes.float32) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 99f9f09690..fd58cdb170 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -266,6 +266,27 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): self.assertAllClose(np_ans, tf_ans) self.assertShapeEqual(np_ans, s) + def testNumSegmentsTypes(self): + dtypes = [dtypes_lib.int32, dtypes_lib.int64] + indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + num_segments = 12 + for indices in indices_flat, indices_flat.reshape(5, 2): + shape = indices.shape + (2,) + for dtype in dtypes: + with self.test_session(use_gpu=True): + tf_x, np_x = self._input(shape) + num_segments_constant = constant_op.constant( + num_segments, dtype=dtype) + np_ans = self._segmentReduce( + indices, np_x, np.add, op2=None, num_out_rows=num_segments) + s = math_ops.unsorted_segment_sum( + data=tf_x, + segment_ids=indices, + num_segments=num_segments_constant) + tf_ans = s.eval() + self.assertAllClose(np_ans, tf_ans) + self.assertShapeEqual(np_ans, s) + def testGradientSegmentSum(self): num_cols = 2 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 78c113f514..d1a90952c7 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -254,8 +254,8 @@ class SerializeSparseTest(test.TestCase): serialized_concat, dtype=dtypes.int32) with self.assertRaisesOpError( - r"Inconsistent rank across SparseTensors: rank prior to " - r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"): + r"Inconsistent shape across SparseTensors: rank prior to " + r"SparseTensor\[1\] was: 2 but rank of SparseTensor\[1\] is: 3"): sess.run(sp_deserialized, {sp_input0: input0_val, sp_input1: input1_val}) diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 40c0ade62a..f0354374ac 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -34,9 +34,10 @@ from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent -def variable_scoped_function(): +def variable_scoped_function(trainable=True): return variable_scope.get_variable( - "dummy", shape=[1], initializer=init_ops.zeros_initializer()) + "dummy", shape=[1], trainable=trainable, + initializer=init_ops.zeros_initializer()) def internally_variable_scoped_function(scope_name): @@ -413,7 +414,7 @@ class TemplateTest(test.TestCase): self.assertEqual(custom_getter_count[0], 2) # Test that custom getter is called when the variable scope is created - # during construction + # during construction custom_getter_count[0] = 0 tmpl2 = template.make_template( "s2", @@ -539,6 +540,36 @@ class TemplateTest(test.TestCase): # Ensure we can get the scopes before either template is actually called. self.assertEqual(1, len(ta.trainable_variables)) self.assertEqual(1, len(tb.trainable_variables)) + # None non-trainable variable was created. + self.assertEqual([], list(ta.non_trainable_variables)) + self.assertEqual([], list(tb.non_trainable_variables)) + # Ensure variables returns all the variables. + self.assertEqual(1, len(ta.variables)) + self.assertEqual(1, len(tb.variables)) + + @test_util.run_in_graph_and_eager_modes() + def test_non_trainable_variables(self): + # Make sure non_trainable_variables are created. + with variable_scope.variable_scope("foo2"): + ta = template.make_template("a", variable_scoped_function, + trainable=True) + tb = template.make_template("b", variable_scoped_function, + trainable=False) + # Initially there are not variables created. + self.assertEqual([], list(ta.variables)) + self.assertEqual([], list(tb.variables)) + # After calling there are variables created. + ta() + tb() + # Check the trainable and non_trainable variables. + self.assertEqual(1, len(ta.trainable_variables)) + self.assertEqual([], list(ta.non_trainable_variables)) + + self.assertEqual([], list(tb.trainable_variables)) + self.assertEqual(1, len(tb.non_trainable_variables)) + # Ensure variables returns all the variables. + self.assertEqual(1, len(ta.variables)) + self.assertEqual(1, len(tb.variables)) # TODO(apassos) handle local variables in Eager def test_local_variables(self): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 6be2bc3e76..c083f8a5d2 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -103,10 +103,16 @@ class Layer(object): self.built = False self.input_spec = None + if activity_regularizer and context.in_eager_mode(): + raise ValueError( + ('Activity regularization is not supported when executing eagerly. ' + 'Got activity_regularizer=%s') % (activity_regularizer,)) self._activity_regularizer = activity_regularizer self._trainable_weights = [] self._non_trainable_weights = [] self._updates = [] + # When executing eagerly, _losses is a list of zero-argument lambdas which + # return tensors. When using graph execution, _losses is a list of ops. self._losses = [] self._reuse = kwargs.get('_reuse') self._graph = ops.get_default_graph() @@ -287,9 +293,22 @@ class Layer(object): @property def losses(self): + """Losses which are associated with this `Layer`. + + Note that when executing eagerly, getting this property evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. + + Returns: + A list of tensors. + """ if context.in_eager_mode(): - raise RuntimeError('Layer.losses not supported in Eager mode.') - return self._losses + # _losses may only contain variable regularization losses when executing + # eagerly, and they have been saved as lambdas to be executed when + # requested. + return [regularizer() for regularizer in self._losses] + else: + return self._losses def add_loss(self, losses, inputs=None): """Add loss tensor(s), potentially dependent on layer inputs. @@ -303,6 +322,11 @@ class Layer(object): The `get_losses_for` method allows to retrieve the losses relevant to a specific set of inputs. + Note that `add_loss` is not supported when executing eagerly. Instead, + variable regularizers may be added through `add_variable`. Activity + regularization is not supported directly (but such losses may be returned + from `Layer.call()`). + Arguments: losses: Loss tensor, or list/tuple of tensors. inputs: Optional input tensor(s) that the loss(es) depend on. Must @@ -462,16 +486,8 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode with regularizers. """ - # Note that we currently don't support variable regularization in Eager - # mode. An alternative is for users to directly compute these losses before - # performing a backward pass. if context.in_graph_mode(): existing_variables = set(tf_variables.global_variables()) - else: - existing_variables = [] - if regularizer is not None: - raise RuntimeError('Variable regularization not supported in Eager ' - 'mode.') if dtype is None: dtype = self.dtype or dtypes.float32 @@ -486,28 +502,39 @@ class Layer(object): constraint=constraint, trainable=trainable and self.trainable, partitioner=partitioner) - if (context.in_graph_mode() and trainable and self.trainable - and variable not in tf_variables.trainable_variables()): - # A custom getter / variable scope overrode the trainable flag. - trainable = False - if variable in existing_variables: - return variable - if regularizer: - # To match the behavior of tf.get_variable(), we only - # apply regularization if the variable is newly created. - if isinstance(variable, tf_variables.PartitionedVariable): - for v in variable: - with ops.colocate_with(v.op): + if context.in_graph_mode(): + if (trainable and self.trainable + and variable not in tf_variables.trainable_variables()): + # A custom getter / variable scope overrode the trainable flag. + trainable = False + if variable in existing_variables: + return variable + if regularizer: + # To match the behavior of tf.get_variable(), we only + # apply regularization if the variable is newly created. + if isinstance(variable, tf_variables.PartitionedVariable): + for v in variable: + with ops.colocate_with(v.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(v) + if regularization is not None: + self.add_loss(regularization) + else: + with ops.colocate_with(variable.op): with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(v) + regularization = regularizer(variable) if regularization is not None: self.add_loss(regularization) - else: - with ops.colocate_with(variable.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(variable) - if regularization is not None: - self.add_loss(regularization) + elif regularizer: + if isinstance(variable, tf_variables.PartitionedVariable): + raise RuntimeError( + 'Partitioned variable regularization is not yet supported when ' + 'executing eagerly. File a feature request is this is ' + 'important to you.') + # Save a zero-argument lambda which runs the regularizer on the + # variable, to be executed when `Layer.losses` is requested. This + # makes losses responsive to variable updates when executing eagerly. + self._losses.append(lambda: regularizer(variable)) if trainable: self._trainable_weights.append(variable) else: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 1eea20deef..3e5a51eb62 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -88,6 +88,11 @@ class BaseLayerTest(test.TestCase): regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + def testNoEagerActivityRegularizer(self): + with context.eager_mode(): + with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): + core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.) + def testGetVariable(self): with self.test_session(): diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 7213fa1db8..fbb13bb72c 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -1232,7 +1232,8 @@ class Conv2DTranspose(Conv2D): def build(self, input_shape): if len(input_shape) != 4: - raise ValueError('Inputs should have rank 4. Received input shape: ' + str(input_shape)) + raise ValueError('Inputs should have rank 4. Received input shape: ' + + str(input_shape)) if self.data_format == 'channels_first': channel_axis = 1 else: diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 8bf831f8ba..a42282b055 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -22,11 +22,11 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" +#include "tensorflow/python/lib/core/py_util.h" #include <Python.h> namespace tensorflow { @@ -133,48 +133,6 @@ bool IsSingleNone(PyObject* obj) { return item == Py_None; } -// py.__class__.__name__ -const char* ClassName(PyObject* py) { -/* PyPy doesn't have a separate C API for old-style classes. */ -#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION) - if (PyClass_Check(py)) - return PyString_AS_STRING( - CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name)); - if (PyInstance_Check(py)) - return PyString_AS_STRING(CHECK_NOTNULL( - reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name)); -#endif - if (Py_TYPE(py) == &PyType_Type) { - return reinterpret_cast<PyTypeObject*>(py)->tp_name; - } - return Py_TYPE(py)->tp_name; -} - -string PyExcFetch() { - CHECK(PyErr_Occurred()) << "Must only call PyExcFetch after an exception."; - PyObject* ptype; - PyObject* pvalue; - PyObject* ptraceback; - PyErr_Fetch(&ptype, &pvalue, &ptraceback); - PyErr_NormalizeException(&ptype, &pvalue, &ptraceback); - string err = ClassName(ptype); - if (pvalue) { - PyObject* str = PyObject_Str(pvalue); - if (str) { -#if PY_MAJOR_VERSION < 3 - strings::StrAppend(&err, ": ", PyString_AS_STRING(str)); -#else - strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str)); -#endif - Py_DECREF(str); - } - Py_DECREF(pvalue); - } - Py_DECREF(ptype); - Py_XDECREF(ptraceback); - return err; -} - // Calls the registered py function through the trampoline. Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { *out_log_on_error = true; @@ -195,18 +153,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { if (PyErr_Occurred()) { if (PyErr_ExceptionMatches(PyExc_ValueError) || PyErr_ExceptionMatches(PyExc_TypeError)) { - return errors::InvalidArgument(PyExcFetch()); + return errors::InvalidArgument(PyExceptionFetch()); } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) { *out_log_on_error = false; - return errors::OutOfRange(PyExcFetch()); + return errors::OutOfRange(PyExceptionFetch()); } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) { - return errors::ResourceExhausted(PyExcFetch()); + return errors::ResourceExhausted(PyExceptionFetch()); } else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) { - return errors::Unimplemented(PyExcFetch()); + return errors::Unimplemented(PyExceptionFetch()); } else { // TODO(ebrevdo): Check if exception is an OpError and use the // OpError.error_code property to map it back in the Status. - return errors::Unknown(PyExcFetch()); + return errors::Unknown(PyExceptionFetch()); } } else { return errors::Internal("Failed to run py callback ", call->token, diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 71cb38f8fd..317bdc2e14 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/numpy.h" +#include "tensorflow/python/lib/core/py_util.h" #include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { @@ -89,12 +90,25 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { *dtype = DT_STRING; } else if (PySequence_Check(obj)) { auto length = PySequence_Length(obj); - shape->AddDim(length); if (length > 0) { + shape->AddDim(length); obj = PySequence_GetItem(obj, 0); continue; - } else { + } else if (length == 0) { + shape->AddDim(length); *dtype = DT_INVALID; // Invalid dtype for empty tensors. + } else { + // The sequence does not have a valid length (PySequence_Length < 0). + if (PyErr_Occurred()) { + // PySequence_Length failed and set an exception. Fetch the message + // and convert it to a failed status. + return errors::InvalidArgument(PyExceptionFetch()); + } else { + // This is almost certainly dead code: PySequence_Length failed but + // did not set an exception. + return errors::InvalidArgument( + "Attempted to convert an invalid sequence to a Tensor."); + } } } else if (IsPyFloat(obj)) { *dtype = DT_DOUBLE; diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc new file mode 100644 index 0000000000..2635694e23 --- /dev/null +++ b/tensorflow/python/lib/core/py_util.cc @@ -0,0 +1,70 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/py_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include <Python.h> + +namespace tensorflow { +namespace { + +// py.__class__.__name__ +const char* ClassName(PyObject* py) { +/* PyPy doesn't have a separate C API for old-style classes. */ +#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION) + if (PyClass_Check(py)) + return PyString_AS_STRING( + CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name)); + if (PyInstance_Check(py)) + return PyString_AS_STRING(CHECK_NOTNULL( + reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name)); +#endif + if (Py_TYPE(py) == &PyType_Type) { + return reinterpret_cast<PyTypeObject*>(py)->tp_name; + } + return Py_TYPE(py)->tp_name; +} + +} // end namespace + +string PyExceptionFetch() { + CHECK(PyErr_Occurred()) + << "Must only call PyExceptionFetch after an exception."; + PyObject* ptype; + PyObject* pvalue; + PyObject* ptraceback; + PyErr_Fetch(&ptype, &pvalue, &ptraceback); + PyErr_NormalizeException(&ptype, &pvalue, &ptraceback); + string err = ClassName(ptype); + if (pvalue) { + PyObject* str = PyObject_Str(pvalue); + if (str) { +#if PY_MAJOR_VERSION < 3 + strings::StrAppend(&err, ": ", PyString_AS_STRING(str)); +#else + strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str)); +#endif + Py_DECREF(str); + } + Py_DECREF(pvalue); + } + Py_DECREF(ptype); + Py_XDECREF(ptraceback); + return err; +} + +} // end namespace tensorflow diff --git a/tensorflow/python/lib/core/py_util.h b/tensorflow/python/lib/core/py_util.h new file mode 100644 index 0000000000..44dfe7ba21 --- /dev/null +++ b/tensorflow/python/lib/core/py_util.h @@ -0,0 +1,27 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// Fetch the exception message as a string. An exception must be set +// (PyErr_Occurred() must be true). +string PyExceptionFetch(); +} // end namespace tensorflow + +#endif // TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_ diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/python/lib/core/safe_ptr.cc index 456ea3348b..ce34b6d004 100644 --- a/tensorflow/python/lib/core/safe_ptr.cc +++ b/tensorflow/python/lib/core/safe_ptr.cc @@ -16,25 +16,21 @@ limitations under the License. #include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { -namespace { -inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); } - -} // namespace - -Safe_PyObjectPtr make_safe(PyObject* o) { - return Safe_PyObjectPtr(o, Py_DECREF_wrapper); +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); } Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) { - return Safe_TF_TensorPtr(tensor, TF_DeleteTensor); + return Safe_TF_TensorPtr(tensor); } Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle) { - return Safe_TFE_TensorHandlePtr(handle, TFE_DeleteTensorHandle); + return Safe_TFE_TensorHandlePtr(handle); } Safe_TF_StatusPtr make_safe(TF_Status* status) { - return Safe_TF_StatusPtr(status, TF_DeleteStatus); + return Safe_TF_StatusPtr(status); } + } // namespace tensorflow diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h index 70cd2fdf6c..80db840aeb 100644 --- a/tensorflow/python/lib/core/safe_ptr.h +++ b/tensorflow/python/lib/core/safe_ptr.h @@ -17,39 +17,51 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ #include <memory> -#include <Python.h> +#include <Python.h> #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" namespace tensorflow { +namespace detail { + +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +struct TFTensorDeleter { + void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } +}; + +struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } +}; + +struct TFStatusDeleter { + void operator()(TF_Status* p) const { TF_DeleteStatus(p); } +}; + +} // namespace detail // Safe container for an owned PyObject. On destruction, the reference count of // the contained object will be decremented. -typedef void (*Py_DECREF_wrapper_type)(PyObject*); -typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr; +using Safe_PyObjectPtr = std::unique_ptr<PyObject, detail::PyDecrefDeleter>; Safe_PyObjectPtr make_safe(PyObject* o); // Safe containers for an owned TF_Tensor. On destruction, the tensor will be // deleted by TF_DeleteTensor. -// Note: can't use decltype(&TF_DeleteTensor) due to SWIG -typedef void (*TF_DeleteTensor_type)(TF_Tensor*); -typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr; +using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, detail::TFTensorDeleter>; Safe_TF_TensorPtr make_safe(TF_Tensor* tensor); // Safe containers for an owned TFE_TensorHandle. On destruction, the handle -// will be deleted by TFE_DeleteTensorHandle. Note: can't use -// decltype(&TFE_DeleteTensorHandle) due to SWIG -typedef void (*TFE_DeleteTensorHandle_type)(TFE_TensorHandle*); -typedef std::unique_ptr<TFE_TensorHandle, TFE_DeleteTensorHandle_type> - Safe_TFE_TensorHandlePtr; +// will be deleted by TFE_DeleteTensorHandle. +using Safe_TFE_TensorHandlePtr = + std::unique_ptr<TFE_TensorHandle, detail::TFETensorHandleDeleter>; Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle); // Safe containers for an owned TF_Status. On destruction, the handle -// will be deleted by TF_DeleteStatus. Note: can't use -// decltype(&TF_DeleteStatus) due to SWIG -typedef void (*TF_DeleteStatus_type)(TF_Status*); -typedef std::unique_ptr<TF_Status, TF_DeleteStatus_type> Safe_TF_StatusPtr; +// will be deleted by TF_DeleteStatus. +using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, detail::TFStatusDeleter>; Safe_TF_StatusPtr make_safe(TF_Status* status); } // namespace tensorflow diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 4b406ba840..8cd535aa0b 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -41,33 +41,48 @@ def _Conv2DBackpropInputGrad(op, grad): Returns: the gradients w.r.t. the input and the filter """ - return [None, - nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]), - op.inputs[2], op.get_attr("strides"), - op.get_attr("padding"), - op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format")), - nn_ops.conv2d(grad, op.inputs[1], op.get_attr("strides"), - op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format"))] + return [ + None, + nn_ops.conv2d_backprop_filter( + grad, + array_ops.shape(op.inputs[1]), + op.inputs[2], + dilations=op.get_attr("dilations"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"), + data_format=op.get_attr("data_format")), + nn_ops.conv2d( + grad, + op.inputs[1], + dilations=op.get_attr("dilations"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"), + data_format=op.get_attr("data_format")) + ] @ops.RegisterGradient("Conv2DBackpropFilter") def _Conv2DBackpropFilterGrad(op, grad): return [ nn_ops.conv2d_backprop_input( - array_ops.shape(op.inputs[0]), grad, op.inputs[2], - op.get_attr("strides"), - op.get_attr("padding"), - op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format")), - None, + array_ops.shape(op.inputs[0]), + grad, + op.inputs[2], + dilations=op.get_attr("dilations"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"), + data_format=op.get_attr("data_format")), None, nn_ops.conv2d( - op.inputs[0], grad, - op.get_attr("strides"), - op.get_attr("padding"), - op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format")) + op.inputs[0], + grad, + dilations=op.get_attr("dilations"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"), + data_format=op.get_attr("data_format")) ] @@ -466,25 +481,32 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _): @ops.RegisterGradient("Conv2D") def _Conv2DGrad(op, grad): + dilations = op.get_attr("dilations") strides = op.get_attr("strides") padding = op.get_attr("padding") use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu") data_format = op.get_attr("data_format") shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]]) - return [nn_ops.conv2d_backprop_input(shape_0, - op.inputs[1], - grad, - strides, - padding, - use_cudnn_on_gpu, - data_format), - nn_ops.conv2d_backprop_filter(op.inputs[0], - shape_1, - grad, - strides, - padding, - use_cudnn_on_gpu, - data_format)] + return [ + nn_ops.conv2d_backprop_input( + shape_0, + op.inputs[1], + grad, + dilations=dilations, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format), + nn_ops.conv2d_backprop_filter( + op.inputs[0], + shape_1, + grad, + dilations=dilations, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format) + ] @ops.RegisterGradient("DepthwiseConv2dNative") diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index ec7b9372ca..b3c0a22efc 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1205,13 +1205,14 @@ def conv2d_transpose(value, raise ValueError("padding must be either VALID or SAME:" " {}".format(padding)) - return gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_, - filter=filter, - out_backprop=value, - strides=strides, - padding=padding, - data_format=data_format, - name=name) + return gen_nn_ops.conv2d_backprop_input( + input_sizes=output_shape_, + filter=filter, + out_backprop=value, + strides=strides, + padding=padding, + data_format=data_format, + name=name) def atrous_conv2d_transpose(value, @@ -1343,12 +1344,13 @@ def atrous_conv2d_transpose(value, (in_width + pad_right_extra) // rate, output_shape[3]] - value = gen_nn_ops.conv2d_backprop_input(input_sizes=input_sizes, - filter=filters, - out_backprop=value, - strides=[1, 1, 1, 1], - padding="VALID", - data_format="NHWC") + value = gen_nn_ops.conv2d_backprop_input( + input_sizes=input_sizes, + filter=filters, + out_backprop=value, + strides=[1, 1, 1, 1], + padding="VALID", + data_format="NHWC") # The crops argument to batch_to_space includes both padding components. batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra], diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 52fb5131cf..afaff8ca41 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -316,7 +316,7 @@ def random_crop(value, size, seed=None, name=None): return array_ops.slice(value, offset, size, name=name) -def multinomial(logits, num_samples, seed=None, name=None): +def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): """Draws samples from a multinomial distribution. Example: @@ -336,6 +336,7 @@ def multinomial(logits, num_samples, seed=None, name=None): @{tf.set_random_seed} for behavior. name: Optional name for the operation. + output_dtype: integer type to use for the output. Defaults to int64. Returns: The drawn samples of shape `[batch_size, num_samples]`. @@ -344,7 +345,7 @@ def multinomial(logits, num_samples, seed=None, name=None): logits = ops.convert_to_tensor(logits, name="logits") seed1, seed2 = random_seed.get_seed(seed) return gen_random_ops.multinomial( - logits, num_samples, seed=seed1, seed2=seed2) + logits, num_samples, seed=seed1, seed2=seed2, output_dtype=output_dtype) ops.NotDifferentiable("Multinomial") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 343e38f960..652bfa1ebc 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -887,26 +887,19 @@ def _ReadGrad(_, grad): def _GatherGrad(op, grad): """Gradient for gather op.""" # Build appropriately shaped IndexedSlices - # Walk graph back until the original handle is found. - # TODO(apassos): more robust way of getting the shape. - # TODO(apassos): implement this for EAGER mode. - if context.in_eager_mode(): - dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0]) - return (ops.IndexedSlices(grad, - op.inputs[1], - dense_shape=dense_shape), - None) handle = op.inputs[0] - while handle.op.type != "VarHandleOp": - handle = handle.op.inputs[0] - params_shape = ops.convert_to_tensor( - tensor_shape.TensorShape(handle.op.get_attr("shape"))) indices = op.inputs[1] + if context.in_graph_mode(): + # Walk graph back until the original handle is found. + # TODO(apassos): implement this for EAGER mode. + while handle.op.type != "VarHandleOp": + handle = handle.op.inputs[0] + params_shape = gen_resource_variable_ops.variable_shape(handle) size = array_ops.expand_dims(array_ops.size(indices), 0) values_shape = array_ops.concat([size, params_shape[1:]], 0) values = array_ops.reshape(grad, values_shape) indices = array_ops.reshape(indices, size) - return [ops.IndexedSlices(values, indices, params_shape), None] + return (ops.IndexedSlices(values, indices, params_shape), None) def _to_proto_fn(v, export_scope=None): diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index cdfe9e1c1e..9bdc124c83 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1437,10 +1437,47 @@ def serialize_many_sparse(sp_input, name=None): def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): """Deserialize `SparseTensor` objects. - The input is expected to have shape [d_1, ..., d_m, 3], where the last - dimension stores a serialized `SparseTensor`. The method deserializes - all input `SparseTensor`s, concatenates them into a single tensor, and - reshapes the sparse tensor to preserve the structure of the input. + The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where + the last dimension stores serialized `SparseTensor` objects and the other N + dimensions (N >= 0) correspond to a batch. The ranks of the original + `SparseTensor` objects must all match. When the final `SparseTensor` is + created, its rank is the rank of the incoming `SparseTensor` objects plus N; + the sparse tensors have been concatenated along new dimensions, one for each + batch. + + The output `SparseTensor` object's shape values for the original dimensions + are the max across the input `SparseTensor` objects' shape values for the + corresponding dimensions. The new dimensions match the size of the batch. + + The input `SparseTensor` objects' indices are assumed ordered in + standard lexicographic order. If this is not the case, after this + step run `SparseReorder` to restore index ordering. + + For example, if the serialized input is a `[2 x 3]` matrix representing two + original `SparseTensor` objects: + + index = [ 0] + [10] + [20] + values = [1, 2, 3] + shape = [50] + + and + + index = [ 2] + [10] + values = [4, 5] + shape = [30] + + then the final deserialized `SparseTensor` will be: + + index = [0 0] + [0 10] + [0 20] + [1 2] + [1 10] + values = [1, 2, 3, 4, 5] + shape = [2 50] Args: serialized_sparse: The serialized `SparseTensor` objects. diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index dfc657893c..dee495f78f 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -347,5 +347,71 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.scatter_update(ref, indices, updates, use_locking=use_locking, name=name) - return gen_resource_variable_ops.resource_scatter_update( - ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), name=name) + with ops.control_dependencies( + [gen_resource_variable_ops.resource_scatter_update( + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)]): + return ref.read_value() + + +def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): + r"""Applies sparse `updates` to individual values or slices in a Variable. + + `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + + `indices` must be integer tensor, containing indices into `ref`. + It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + + The innermost dimension of `indices` (with length `K`) corresponds to + indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th + dimension of `ref`. + + `updates` is `Tensor` of rank `Q-1+P-K` with shape: + + ``` + [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. + ``` + + For example, say we want to update 4 scattered elements to a rank-1 tensor to + 8 elements. In Python, that update would look like this: + + ```python + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) + ``` + + The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + + See @{tf.scatter_nd} for more details about how to make updates to + slices. + + Args: + ref: A Variable. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. + updates: A `Tensor`. Must have the same type as `ref`. + A Tensor. Must have the same type as ref. A tensor of updated + values to add to ref. + use_locking: An optional `bool`. Defaults to `True`. + An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + The value of the variable after the update. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_nd_update( + ref, indices, updates, use_locking, name) + with ops.control_dependencies([gen_state_ops.resource_scatter_nd_update( + ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype), + use_locking, name)]): + return ref.read_value() diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 98578b799a..07796b28d9 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -308,6 +308,12 @@ class Template(object): return name if name[-1] == "/" else name + "/" @property + def variables(self): + """Returns the list of global and local variables created by the Template. + """ + return self.global_variables + self.local_variables + + @property def trainable_variables(self): """Returns the list of trainable variables created by the Template.""" if self._variables_created: @@ -317,6 +323,14 @@ class Template(object): return [] @property + def non_trainable_variables(self): + """Returns the list of non-trainable variables created by the Template.""" + # TODO(apassos) Make sure it matches Eager when using local variables. + global_variables = self.global_variables + trainable_variables = set(self.trainable_variables) + return [x for x in global_variables if x not in trainable_variables] + + @property def global_variables(self): """Returns the list of global variables created by the Template.""" if self._variables_created: @@ -335,6 +349,21 @@ class Template(object): return [] @property + def weights(self): + """List of weights/variables created by the Template.""" + return self.variables + + @property + def trainable_weights(self): + """List of trainable weights/variables created by the Template.""" + return self.trainable_variables + + @property + def non_trainable_weights(self): + """List of non-trainable weights/variables created by the Template.""" + return self.non_trainable_variables + + @property @deprecated( "2017-02-21", "The .var_scope property is deprecated. Please change your " "code to use the .variable_scope property") @@ -501,7 +530,7 @@ class EagerTemplate(Template): @property def variables(self): - """Returns the list of trainable variables created by the Template.""" + """Returns the list of variables created by the Template.""" # Currently there is no local variable in Eager mode. return self._eager_variable_store.variables() @@ -512,6 +541,12 @@ class EagerTemplate(Template): return self._eager_variable_store.trainable_variables() @property + def non_trainable_variables(self): + """Returns the list of non-trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.non_trainable_variables() + + @property def global_variables(self): """Returns the list of global variables created by the Template.""" # Currently there is no local variable in Eager mode. diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py index e9a36ae75d..abd6f3d855 100644 --- a/tensorflow/python/platform/flags.py +++ b/tensorflow/python/platform/flags.py @@ -18,5 +18,53 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging as _logging + # go/tf-wildcard-import from absl.flags import * # pylint: disable=wildcard-import +import six as _six + +from tensorflow.python.util import tf_decorator + + +# Since we wrap absl.flags DEFINE functions, we need to declare this module +# does not affect key flags. +disclaim_key_flags() # pylint: disable=undefined-variable + + +_RENAMED_ARGUMENTS = { + 'flag_name': 'name', + 'default_value': 'default', + 'docstring': 'help', +} + + +def _wrap_define_function(original_function): + """Wraps absl.flags's define functions so tf.flags accepts old names.""" + + def wrapper(*args, **kwargs): + """Wrapper function that turns old keyword names to new ones.""" + has_old_names = False + for old_name, new_name in _six.iteritems(_RENAMED_ARGUMENTS): + if old_name in kwargs: + has_old_names = True + value = kwargs.pop(old_name) + kwargs[new_name] = value + if has_old_names: + _logging.warning( + 'Use of the keyword argument names (flag_name, default_value, ' + 'docstring) is deprecated, please use (name, default, help) instead.') + return original_function(*args, **kwargs) + + return tf_decorator.make_decorator(original_function, wrapper) + + +# pylint: disable=invalid-name,used-before-assignment +# absl.flags APIs use `default` as the name of the default value argument. +# Allow the following functions continue to accept `default_value`. +DEFINE_string = _wrap_define_function(DEFINE_string) +DEFINE_boolean = _wrap_define_function(DEFINE_boolean) +DEFINE_bool = DEFINE_boolean +DEFINE_float = _wrap_define_function(DEFINE_float) +DEFINE_integer = _wrap_define_function(DEFINE_integer) +# pylint: enable=invalid-name,used-before-assignment diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py index 23060e17d2..e8200142dd 100644 --- a/tensorflow/python/platform/flags_test.py +++ b/tensorflow/python/platform/flags_test.py @@ -24,11 +24,50 @@ from absl import flags as absl_flags from tensorflow.python.platform import flags +flags.DEFINE_string( + flag_name='old_string', default_value='default', docstring='docstring') +flags.DEFINE_string( + name='new_string', default='default', help='docstring') +flags.DEFINE_integer( + flag_name='old_integer', default_value=1, docstring='docstring') +flags.DEFINE_integer( + name='new_integer', default=1, help='docstring') +flags.DEFINE_float( + flag_name='old_float', default_value=1.5, docstring='docstring') +flags.DEFINE_float( + name='new_float', default=1.5, help='docstring') +flags.DEFINE_bool( + flag_name='old_bool', default_value=True, docstring='docstring') +flags.DEFINE_bool( + name='new_bool', default=True, help='docstring') +flags.DEFINE_boolean( + flag_name='old_boolean', default_value=False, docstring='docstring') +flags.DEFINE_boolean( + name='new_boolean', default=False, help='docstring') + + class FlagsTest(unittest.TestCase): def test_global_flags_object(self): self.assertIs(flags.FLAGS, absl_flags.FLAGS) + def test_keyword_arguments(self): + test_cases = ( + ('old_string', 'default'), + ('new_string', 'default'), + ('old_integer', 1), + ('new_integer', 1), + ('old_float', 1.5), + ('new_float', 1.5), + ('old_bool', True), + ('new_bool', True), + ('old_boolean', False), + ('new_boolean', False), + ) + for flag_name, default_value in test_cases: + self.assertEqual(default_value, absl_flags.FLAGS[flag_name].default) + self.assertEqual('docstring', absl_flags.FLAGS[flag_name].help) + -if __name__ == "__main__": +if __name__ == '__main__': unittest.main() diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 26fb99efe6..ccfb9aac53 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -23,12 +23,15 @@ import os import random import re +import numpy as np + from tensorflow.core.profiler import profile_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -346,8 +349,8 @@ class PrintModelAnalysisTest(test.TestCase): with gfile.Open(outfile, 'r') as f: # pylint: disable=line-too-long self.assertEqual( - 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes\nConst0B(0', - f.read().replace('\t', '').replace(' ', '')[0:180]) + 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes', + f.read().replace('\t', '').replace(' ', '')[0:170]) # pylint: enable=line-too-long total_children = 0 @@ -694,6 +697,39 @@ class PrintModelAnalysisTest(test.TestCase): exception_str) self.assertTrue(mat is None) + def testTrackPersistentBytes(self): + ops.reset_default_graph() + a = array_ops.constant(np.ones((100, 100))) + b = array_ops.constant(np.ones((100, 100))) + c = a * b + + with session.Session() as sess: + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + sess.run(c, options=run_options, run_metadata=run_metadata) + + options = option_builder.ProfileOptionBuilder.time_and_memory() + options['min_bytes'] = 0 + options['select'] = ('bytes', 'peak_bytes', 'output_bytes', + 'residual_bytes') + ret = model_analyzer.profile( + sess.graph, run_meta=run_metadata, cmd='scope', options=options) + + run_metadata = config_pb2.RunMetadata() + sess.run(c, options=run_options, run_metadata=run_metadata) + ret2 = model_analyzer.profile( + sess.graph, run_meta=run_metadata, cmd='scope', options=options) + + n = lib.SearchTFProfNode(ret, 'mul') + n2 = lib.SearchTFProfNode(ret2, 'mul') + self.assertGreater(n.peak_bytes, 0) + self.assertGreater(n.output_bytes, 0) + self.assertGreater(n.residual_bytes, 0) + self.assertEqual(n.peak_bytes, n2.peak_bytes) + self.assertEqual(n.output_bytes, n2.output_bytes) + self.assertEqual(n.residual_bytes, n2.residual_bytes) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 82b154164e..82750e9e49 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -18,6 +18,7 @@ limitations under the License. %rename("%s") TFE_NewContext; %rename("%s") TFE_DeleteContext; %rename("%s") TFE_ContextListDevices; +%rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; @@ -149,7 +150,7 @@ limitations under the License. } $1 = &temp; $1->resize(PyInt_AsLong($input), nullptr); -} +} // Create new Status object. %typemap(in, numinputs=0) TF_Status *out_status { diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index 7268b3abc9..6865513b0e 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -234,23 +234,38 @@ class MomentumOptimizerTest(test.TestCase): self.assertAllClose(var0_np, var0.eval()) self.assertAllClose(var1_np, var1.eval()) + @test_util.run_in_graph_and_eager_modes(reset_test=True) def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + + # pylint: disable=cell-var-from-loop + def loss(): x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = momentum_lib.MomentumOptimizer( - learning_rate=1.0, momentum=0.0).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[-111, -138]], var0.eval()) + return pred * pred + # pylint: enable=cell-var-from-loop + + opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0) + sgd_op = opt.minimize(loss if context.in_eager_mode() else loss()) + self.evaluate(variables.global_variables_initializer()) + # Run 1 step of sgd + self.evaluate(sgd_op) + # Validate updated params + self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0)) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testMinimizeWith2DIndiciesForEmbeddingLookup(self): + var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2])) + + def loss(): + return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]])) + + opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0) + sgd_op = opt.minimize(loss if context.in_eager_mode() else loss()) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(sgd_op) + self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0)) def testTensorLearningRateAndMomentum(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index e931555470..f1cb81981a 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -52,7 +52,6 @@ _PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError) USE_DEFAULT = object() -# TODO(touts): Share that with the Supervisor. class Scaffold(object): """Structure to create or gather pieces commonly needed to train a model. diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index b7f1297b8f..74ee1e5fa8 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -774,9 +774,13 @@ class SaveRestoreShardedTest(test.TestCase): with sess.graph.device("/cpu:0"): v0 = variables.Variable(111, name="v0") t0 = saver_test_utils.CheckpointedOp(name="t0") - save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, - write_version=self._WRITE_VERSION, - sharded=True) + save = saver_module.Saver( + { + "v0": v0, + "t0": t0.saveable + }, + write_version=self._WRITE_VERSION, + sharded=True) variables.global_variables_initializer().run() t0.insert("k11", 33.0).run() self.assertEqual(111, v0.eval()) @@ -794,9 +798,13 @@ class SaveRestoreShardedTest(test.TestCase): with sess.graph.device("/cpu:0"): v1 = variables.Variable(222) t1 = saver_test_utils.CheckpointedOp(name="t1") - save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, - write_version=self._WRITE_VERSION, - sharded=True) + save = saver_module.Saver( + { + "v1": v1, + "t1": t1.saveable + }, + write_version=self._WRITE_VERSION, + sharded=True) variables.global_variables_initializer().run() t1.insert("k22", 44.0).run() self.assertEqual(222, v1.eval()) diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index a634a842b6..e4514aaea2 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -36,11 +36,15 @@ from tensorflow.python.training import coordinator from tensorflow.python.training import saver as saver_mod from tensorflow.python.training import session_manager as session_manager_mod from tensorflow.python.training import training_util +from tensorflow.python.util import deprecation class Supervisor(object): """A training helper that checkpoints models and computes summaries. + This class is deprecated. Please use + ${tf.train.MonitoredTrainingSession} instead. + The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, and a `SessionManager` that takes care of common needs of TensorFlow training programs. @@ -198,6 +202,8 @@ class Supervisor(object): # the default behavior should be used. USE_DEFAULT = 0 + @deprecation.deprecated(None, + "Please switch to tf.train.MonitoredTrainingSession") def __init__(self, graph=None, ready_op=USE_DEFAULT, diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index f5802d9359..5c066e2bef 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -456,9 +456,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): if set(input_tree) != set(shallow_tree): raise ValueError( "The two structures don't have the same keys. Input " - "structure has keys %s, while shallow structure has keys %s." - % (list(_six.iterkeys(input_tree)), - list(_six.iterkeys(shallow_tree)))) + "structure has keys %s, while shallow structure has keys %s." % + (list(_six.iterkeys(input_tree)), + list(_six.iterkeys(shallow_tree)))) input_tree = list(_six.iteritems(input_tree)) shallow_tree = list(_six.iteritems(shallow_tree)) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 26aeaeec19..3d9e9f9684 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -388,8 +388,9 @@ class NestTest(test.TestCase): inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} expected_message = ( - "The two structures don't have the same keys. Input " - "structure has keys \['c'\], while shallow structure has keys \['d'\].") + r"The two structures don't have the same keys. Input " + r"structure has keys \['c'\], while shallow structure has " + r"keys \['d'\].") with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) @@ -438,8 +439,7 @@ class NestTest(test.TestCase): input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) - shallow_tree = collections.OrderedDict([("a", 0), - ("c", {"d": 3, "e": 1})]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 8d392fb36d..76ef59484f 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -167,7 +167,19 @@ WIN_COPTS = [ ] # LINT.IfChange -def tf_copts(): +def tf_copts(android_optimization_level_override="-O2"): + # For compatibility reasons, android_optimization_level_override + # is currently only being set for Android. + # To clear this value, and allow the CROSSTOOL default + # to be used, pass android_optimization_level_override=None + android_copts = [ + "-std=c++11", + "-DTF_LEAN_BINARY", + "-Wno-narrowing", + "-fomit-frame-pointer", + ] + if android_optimization_level_override: + android_copts.append(android_optimization_level_override) return ( if_not_windows([ "-DEIGEN_AVOID_STL_ARRAY", @@ -180,13 +192,7 @@ def tf_copts(): + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({ - clean_dep("//tensorflow:android"): [ - "-std=c++11", - "-DTF_LEAN_BINARY", - "-O2", - "-Wno-narrowing", - "-fomit-frame-pointer", - ], + clean_dep("//tensorflow:android"): android_copts, clean_dep("//tensorflow:darwin"): [], clean_dep("//tensorflow:windows"): WIN_COPTS, clean_dep("//tensorflow:windows_msvc"): WIN_COPTS, diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index ebd9c079b5..d920fef770 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -54,15 +54,15 @@ tf_module { } member_method { name: "conv2d" - argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "conv2d_backprop_filter" - argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "conv2d_backprop_input" - argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], " + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "conv2d_transpose" @@ -70,11 +70,11 @@ tf_module { } member_method { name: "conv3d" - argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], " } member_method { name: "conv3d_backprop_filter_v2" - argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], " } member_method { name: "conv3d_transpose" @@ -106,15 +106,15 @@ tf_module { } member_method { name: "depthwise_conv2d_native" - argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "depthwise_conv2d_native_backprop_filter" - argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "depthwise_conv2d_native_backprop_input" - argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "dilation2d" @@ -234,7 +234,7 @@ tf_module { } member_method { name: "quantized_conv2d" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'qint32\'>\", \'None\'], " + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'qint32\'>\", \'[1, 1, 1, 1]\', \'None\'], " } member_method { name: "quantized_max_pool" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 0edd4153d7..57573d5024 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1394,7 +1394,7 @@ tf_module { } member_method { name: "multinomial" - argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "multiply" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 404a9a6b62..4021d794b6 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -99,7 +99,8 @@ do_pylint() { "^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ "^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ -"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition" +"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\ +"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\"" diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py index 6f578d6f67..8feb5386e9 100644 --- a/tensorflow/tools/dist_test/python/census_widendeep.py +++ b/tensorflow/tools/dist_test/python/census_widendeep.py @@ -263,8 +263,7 @@ if __name__ == "__main__": "--data_dir", type=str, default="/tmp/census-data", - help="Directory for storing the census data" - ) + help="Directory for storing the census data") parser.add_argument( "--model_dir", type=str, diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index c18f20910a..3852b251d9 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -33,7 +33,12 @@ _VERSION = '1.4.0' REQUIRED_PACKAGES = [ 'absl-py', - 'enum34 >= 1.1.6', + # weakref.finalize introduced in Python 3.4 + 'backports.weakref >= 1.0rc1; python_version < "3.4"', + # enum module introduced in Python 3.4 + 'enum34 >= 1.1.6; python_version < "3.4"', + # Needed for unittest.mock in Python 2 + 'mock >= 2.0.0; python_version < "3.0"', 'numpy >= 1.12.1', 'six >= 1.10.0', 'protobuf >= 3.4.0', @@ -52,8 +57,6 @@ if sys.version_info.major == 3: REQUIRED_PACKAGES.append('wheel >= 0.26') else: REQUIRED_PACKAGES.append('wheel') - # mock comes with unittest.mock for python3, need to install for python2 - REQUIRED_PACKAGES.append('mock >= 2.0.0') # tf-nightly should depend on tb-nightly if 'tf_nightly' in project_name: @@ -62,10 +65,6 @@ if 'tf_nightly' in project_name: REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.5.0a0, < 1.6.0a0' break -# weakref.finalize was introduced in Python 3.4 -if sys.version_info < (3, 4): - REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1') - # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:main', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 6b13271002..c2256b6313 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -57,33 +57,6 @@ def check_version(bazel_version): fail("\nCurrent Bazel version is {}, expected at least {}\n".format( native.bazel_version, bazel_version)) -def _repos_are_siblings(): - return Label("@foo//bar").workspace_root.startswith("../") - -# Temporary workaround to support including TensorFlow as a submodule until this -# use-case is supported in the next Bazel release. -def _temp_workaround_http_archive_impl(repo_ctx): - repo_ctx.template("BUILD", repo_ctx.attr.build_file, { - "%prefix%": ".." if _repos_are_siblings() else "external", - "%ws%": repo_ctx.attr.repository - }, False) - repo_ctx.download_and_extract(repo_ctx.attr.urls, "", repo_ctx.attr.sha256, - "", repo_ctx.attr.strip_prefix) - if repo_ctx.attr.patch_file != None: - _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - -temp_workaround_http_archive = repository_rule( - attrs = { - "build_file": attr.label(), - "repository": attr.string(), - "patch_file": attr.label(default = None), - "urls": attr.string_list(default = []), - "sha256": attr.string(default = ""), - "strip_prefix": attr.string(default = ""), - }, - implementation = _temp_workaround_http_archive_impl, -) - # Executes specified command with arguments and calls 'fail' if it exited with # non-zero code def _execute_and_check_ret_code(repo_ctx, cmd_and_args): @@ -121,8 +94,6 @@ def _patched_http_archive_impl(repo_ctx): patched_http_archive = repository_rule( attrs = { "patch_file": attr.label(), - "build_file": attr.label(), - "repository": attr.string(), "urls": attr.string_list(default = []), "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), @@ -157,7 +128,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b", strip_prefix = "mklml_lnx_2018.0.20170720", build_file = str(Label("//third_party/mkl:mkl.BUILD")), - repository = tf_repo_name, ) if path_prefix: @@ -292,7 +262,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:nasm.BUILD")), ) - temp_workaround_http_archive( + native.new_http_archive( name = "jpeg", urls = [ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", @@ -301,7 +271,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", strip_prefix = "libjpeg-turbo-1.5.1", build_file = str(Label("//third_party/jpeg:jpeg.BUILD")), - repository = tf_repo_name, ) native.new_http_archive( @@ -447,11 +416,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", - "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz", + "https://github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz", ], - sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b", - strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323", + sha256 = "51f81ff4202bbb820cdbedc061bd2eb6765f2b5c06489e7a8694bedac329e8f8", + strip_prefix = "nsync-8502189abfa44c249c01c2cad64e6ed660a9a668", ) native.http_archive( @@ -502,7 +471,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:swig.BUILD")), ) - temp_workaround_http_archive( + native.new_http_archive( name = "curl", sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", urls = [ @@ -511,7 +480,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], strip_prefix = "curl-7.49.1", build_file = str(Label("//third_party:curl.BUILD")), - repository = tf_repo_name ) # grpc expects //external:protobuf_clib and //external:protobuf_compiler @@ -575,16 +543,15 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror. # Switch to an official source of snapshots if/when possible. - temp_workaround_http_archive( + native.new_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/8d26b8bee4d8e7230870a600bc968c7ee8cf6f67.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/8d26b8bee4d8e7230870a600bc968c7ee8cf6f67.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/9ab4c272cb604a7f947865428c4ef2169fee2100.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/9ab4c272cb604a7f947865428c4ef2169fee2100.tar.gz", ], - sha256 = "ff5ddbe5af5e264426c8d489e7fddfc5ad7e0975f19cefe9db8c0a5d0faeb23e", - strip_prefix = "llvm-8d26b8bee4d8e7230870a600bc968c7ee8cf6f67", + sha256 = "1b1b7d3800a94ca2302e3dd670dbe84238749583027883784b55297059d83da8", + strip_prefix = "llvm-9ab4c272cb604a7f947865428c4ef2169fee2100", build_file = str(Label("//third_party/llvm:llvm.BUILD")), - repository = tf_repo_name, ) native.new_http_archive( @@ -650,7 +617,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party/fft2d:fft2d.BUILD")), ) - temp_workaround_http_archive( + native.new_http_archive( name = "snappy", urls = [ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", @@ -659,10 +626,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", strip_prefix = "snappy-1.1.4", build_file = str(Label("//third_party:snappy.BUILD")), - repository = tf_repo_name, ) - temp_workaround_http_archive( + native.new_http_archive( name = "nccl_archive", urls = [ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", @@ -671,10 +637,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", build_file = str(Label("//third_party:nccl.BUILD")), - repository = tf_repo_name, ) - temp_workaround_http_archive( + native.new_http_archive( name = "aws", urls = [ "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", @@ -683,7 +648,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "f599b57aec4f03ad696044dd430b2d201864113937353adc346f53ad47991319", strip_prefix = "aws-sdk-cpp-1.0.90", build_file = str(Label("//third_party:aws.BUILD")), - repository = tf_repo_name ) java_import_external( @@ -711,7 +675,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): testonly_ = True, ) - temp_workaround_http_archive( + native.new_http_archive( name = "jemalloc", urls = [ "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", @@ -720,7 +684,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", strip_prefix = "jemalloc-4.4.0", build_file = str(Label("//third_party:jemalloc.BUILD")), - repository = tf_repo_name, ) java_import_external( diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD index bc9e37ffb3..bf5310aa16 100644 --- a/third_party/aws.BUILD +++ b/third_party/aws.BUILD @@ -7,21 +7,21 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("@%ws%//third_party:common.bzl", "template_rule") +load("@org_tensorflow//third_party:common.bzl", "template_rule") cc_library( name = "aws", srcs = select({ - "@%ws%//tensorflow:linux_x86_64": glob([ + "@org_tensorflow//tensorflow:linux_x86_64": glob([ "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", ]), - "@%ws%//tensorflow:darwin": glob([ + "@org_tensorflow//tensorflow:darwin": glob([ "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", ]), - "@%ws%//tensorflow:linux_ppc64le": glob([ + "@org_tensorflow//tensorflow:linux_ppc64le": glob([ "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", ]), - "@%ws%//tensorflow:raspberry_pi_armeabi": glob([ + "@org_tensorflow//tensorflow:raspberry_pi_armeabi": glob([ "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", ]), "//conditions:default": [], @@ -53,17 +53,17 @@ cc_library( "aws-cpp-sdk-core/include/aws/core/SDKConfig.h", ], defines = select({ - "@%ws%//tensorflow:linux_x86_64": [ + "@org_tensorflow//tensorflow:linux_x86_64": [ "PLATFORM_LINUX", "ENABLE_CURL_CLIENT", "ENABLE_NO_ENCRYPTION", ], - "@%ws%//tensorflow:darwin": [ + "@org_tensorflow//tensorflow:darwin": [ "PLATFORM_APPLE", "ENABLE_CURL_CLIENT", "ENABLE_NO_ENCRYPTION", ], - "@%ws%//tensorflow:linux_ppc64le": [ + "@org_tensorflow//tensorflow:linux_ppc64le": [ "PLATFORM_LINUX", "ENABLE_CURL_CLIENT", "ENABLE_NO_ENCRYPTION", diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 805a30d262..e311c7e758 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -6,7 +6,7 @@ licenses(["notice"]) # MIT/X derivative license exports_files(["COPYING"]) CURL_WIN_COPTS = [ - "/I%prefix%/curl/lib", + "/Iexternal/curl/lib", "/DHAVE_CONFIG_H", "/DCURL_DISABLE_FTP", "/DCURL_DISABLE_NTLM", @@ -224,14 +224,14 @@ cc_library( "lib/wildcard.h", "lib/x509asn1.h", ] + select({ - "@%ws%//tensorflow:darwin": [ + "@org_tensorflow//tensorflow:darwin": [ "lib/vtls/darwinssl.c", ], - "@%ws%//tensorflow:ios": [ + "@org_tensorflow//tensorflow:ios": [ "lib/vtls/darwinssl.c", ], - "@%ws%//tensorflow:windows": CURL_WIN_SRCS, - "@%ws%//tensorflow:windows_msvc": CURL_WIN_SRCS, + "@org_tensorflow//tensorflow:windows": CURL_WIN_SRCS, + "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_SRCS, "//conditions:default": [ "lib/vtls/openssl.c", ], @@ -248,10 +248,10 @@ cc_library( "include/curl/typecheck-gcc.h", ], copts = select({ - "@%ws%//tensorflow:windows": CURL_WIN_COPTS, - "@%ws%//tensorflow:windows_msvc": CURL_WIN_COPTS, + "@org_tensorflow//tensorflow:windows": CURL_WIN_COPTS, + "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_COPTS, "//conditions:default": [ - "-I%prefix%/curl/lib", + "-Iexternal/curl/lib", "-D_GNU_SOURCE", "-DHAVE_CONFIG_H", "-DCURL_DISABLE_FTP", @@ -261,14 +261,14 @@ cc_library( "-Wno-string-plus-int", ], }) + select({ - "@%ws%//tensorflow:darwin": [ + "@org_tensorflow//tensorflow:darwin": [ "-fno-constant-cfstrings", ], - "@%ws%//tensorflow:windows": [ + "@org_tensorflow//tensorflow:windows": [ # See curl.h for discussion of write size and Windows "/DCURL_MAX_WRITE_SIZE=16384", ], - "@%ws%//tensorflow:windows_msvc": [ + "@org_tensorflow//tensorflow:windows_msvc": [ # See curl.h for discussion of write size and Windows "/DCURL_MAX_WRITE_SIZE=16384", ], @@ -278,20 +278,20 @@ cc_library( }), includes = ["include"], linkopts = select({ - "@%ws%//tensorflow:android": [ + "@org_tensorflow//tensorflow:android": [ "-pie", ], - "@%ws%//tensorflow:darwin": [ + "@org_tensorflow//tensorflow:darwin": [ "-Wl,-framework", "-Wl,CoreFoundation", "-Wl,-framework", "-Wl,Security", ], - "@%ws%//tensorflow:ios": [], - "@%ws%//tensorflow:windows": [ + "@org_tensorflow//tensorflow:ios": [], + "@org_tensorflow//tensorflow:windows": [ "-Wl,ws2_32.lib", ], - "@%ws%//tensorflow:windows_msvc": [ + "@org_tensorflow//tensorflow:windows_msvc": [ "-Wl,ws2_32.lib", ], "//conditions:default": [ @@ -302,9 +302,9 @@ cc_library( deps = [ "@zlib_archive//:zlib", ] + select({ - "@%ws%//tensorflow:ios": [], - "@%ws%//tensorflow:windows": [], - "@%ws%//tensorflow:windows_msvc": [], + "@org_tensorflow//tensorflow:ios": [], + "@org_tensorflow//tensorflow:windows": [], + "@org_tensorflow//tensorflow:windows_msvc": [], "//conditions:default": [ "@boringssl//:ssl", ], @@ -312,7 +312,7 @@ cc_library( ) CURL_BIN_WIN_COPTS = [ - "/I%prefix%/curl/lib", + "/Iexternal/curl/lib", "/DHAVE_CONFIG_H", "/DCURL_DISABLE_LIBCURL_OPTION", ] @@ -406,10 +406,10 @@ cc_binary( "src/tool_xattr.h", ], copts = select({ - "@%ws%//tensorflow:windows": CURL_BIN_WIN_COPTS, - "@%ws%//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS, + "@org_tensorflow//tensorflow:windows": CURL_BIN_WIN_COPTS, + "@org_tensorflow//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS, "//conditions:default": [ - "-I%prefix%/curl/lib", + "-Iexternal/curl/lib", "-D_GNU_SOURCE", "-DHAVE_CONFIG_H", "-DCURL_DISABLE_LIBCURL_OPTION", diff --git a/third_party/gif.BUILD b/third_party/gif.BUILD index 27808a9d64..78fbd6c0e0 100644 --- a/third_party/gif.BUILD +++ b/third_party/gif.BUILD @@ -21,7 +21,7 @@ cc_library( ], hdrs = ["lib/gif_lib.h"], defines = select({ - #"@%ws%//tensorflow:android": [ + #"@org_tensorflow//tensorflow:android": [ ":android": [ "S_IREAD=S_IRUSR", "S_IWRITE=S_IWUSR", diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD index a2addf2c66..1b0829b8fe 100644 --- a/third_party/jemalloc.BUILD +++ b/third_party/jemalloc.BUILD @@ -5,7 +5,7 @@ licenses(["notice"]) # BSD exports_files(["COPYING"]) -load("@%ws%//third_party:common.bzl", "template_rule") +load("@org_tensorflow//third_party:common.bzl", "template_rule") cc_library( name = "jemalloc_headers", @@ -97,10 +97,10 @@ cc_library( includes = ["include"], # pthread_atfork() is called for PPC. linkopts = select({ - "@%ws%//tensorflow:linux_ppc64le": [ + "@org_tensorflow//tensorflow:linux_ppc64le": [ "-lpthread", ], - "@%ws%//tensorflow:linux_x86_64": [ + "@org_tensorflow//tensorflow:linux_x86_64": [ "-lpthread", ], "//conditions:default": [ @@ -208,8 +208,8 @@ genrule( name = "size_classes_h", outs = ["include/jemalloc/internal/size_classes.h"], cmd = select({ - "@%ws%//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@", - "@%ws%//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@", + "@org_tensorflow//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@", + "@org_tensorflow//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@", "//conditions:default": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@", }), tools = [":size_classes_sh"], diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index f6078052ec..e431f19382 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -5,7 +5,7 @@ licenses(["notice"]) # custom notice-style license, see LICENSE.md exports_files(["LICENSE.md"]) -load("@%ws%//third_party:common.bzl", "template_rule") +load("@org_tensorflow//third_party:common.bzl", "template_rule") libjpegturbo_nocopts = "-[W]error" diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 6574f25092..8b73ddabdd 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -60,7 +60,6 @@ mkl_repository = repository_rule( ], attrs = { "build_file": attr.label(), - "repository": attr.string(), "urls": attr.string_list(default = []), "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), diff --git a/third_party/nccl.BUILD b/third_party/nccl.BUILD index 8c7b9bdbe9..b2b8e18824 100644 --- a/third_party/nccl.BUILD +++ b/third_party/nccl.BUILD @@ -44,17 +44,17 @@ cc_library( "-O3", ] + cuda_default_copts(), linkopts = select({ - "@%ws%//tensorflow:android": [ + "@org_tensorflow//tensorflow:android": [ "-pie", ], - "@%ws%//tensorflow:darwin": [ + "@org_tensorflow//tensorflow:darwin": [ "-Wl,-framework", "-Wl,CoreFoundation", "-Wl,-framework", "-Wl,Security", ], - "@%ws%//tensorflow:ios": [], - "@%ws%//tensorflow:windows": [ + "@org_tensorflow//tensorflow:ios": [], + "@org_tensorflow//tensorflow:windows": [ "-DEFAULTLIB:ws2_32.lib", ], "//conditions:default": [ diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index 9c00b7068a..fd48ed8941 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -50,8 +50,8 @@ genrule( "-e 's/@ac_cv_have_stddef_h@/1/g' " + "-e 's/@ac_cv_have_stdint_h@/1/g' " + select({ - "@%ws%//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", - "@%ws%//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", + "@org_tensorflow//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", + "@org_tensorflow//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", "//conditions:default": "-e 's/@ac_cv_have_sys_uio_h@/1/g' ", }) + "-e 's/@SNAPPY_MAJOR@/1/g' " + |