Diffstat (limited to 'tensorflow/core')
64 files changed, 1913 insertions, 671 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt new file mode 100644 index 0000000000..280148e032 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "LeakyRelu" + visibility: HIDDEN + summary: "Computes leaky rectified linear: `max(features, features * alpha)`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt new file mode 100644 index 0000000000..e427526602 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LeakyReluGrad" + visibility: HIDDEN + in_arg { + name: "gradients" + description: <<END +The backpropagated gradients to the corresponding LeakyRelu operation. +END + } + in_arg { + name: "features" + description: <<END +The features passed as input to the corresponding LeakyRelu operation, +OR the outputs of that operation (both work equivalently). +END + } + out_arg { + name: "backprops" + description: <<END +`gradients * (features > 0) + alpha * gradients * (features <= 0)`. +END + } + summary: "Computes leaky rectified linear gradients for a LeakyRelu operation." +} diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt new file mode 100644 index 0000000000..b6a6dbdf54 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt @@ -0,0 +1,46 @@ +op { + graph_op_name: "StatelessRandomUniformInt" + visibility: HIDDEN + in_arg { + name: "shape" + description: <<END +The shape of the output tensor. +END + } + in_arg { + name: "seed" + description: <<END +2 seeds (shape [2]). +END + } + in_arg { + name: "minval" + description: <<END +Minimum value (inclusive, scalar). +END + } + in_arg { + name: "maxval" + description: <<END +Maximum value (exclusive, scalar). +END + } + out_arg { + name: "output" + description: <<END +Random values with specified shape. +END + } + attr { + name: "dtype" + description: <<END +The type of the output. +END + } + summary: "Outputs deterministic pseudorandom integers from a uniform distribution." + description: <<END +The generated values follow a uniform distribution in the range `[minval, maxval)`. + +The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +END +} diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index db137f1a19..e81e61b633 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -466,23 +466,23 @@ Graph* GetConstantGraph( bool ReplaceTensorWithConstant( Graph* graph, Device* partition_device, NodeAndOutput tensor, const Tensor& constant, const gtl::FlatSet<Node*>& control_deps, - int64 max_constant_size_in_bytes, bool disable_memory_output_type_check, + int64 max_constant_size_in_bytes, const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. // 1) Do not replace another constant. // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 3) If the size of the constant in bytes is too large (> + // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY + // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. This prevents the size of the // Graph from growing too large. - // 4) If the constant op created does not have a kernel implementation + // 5) If the constant op created does not have a kernel implementation // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on its // output. - // 5) If the constant op for the device has different output memory type - // from the original op output memory type, do not replace it. if (tensor.first->IsConstant()) { return false; } @@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant( return false; } bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; - if (memory_type == HOST_MEMORY && !is_int32) { + if ((memory_type == HOST_MEMORY && !is_int32) || + (memory_type == DEVICE_MEMORY && is_int32)) { return false; } } @@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } - if (!disable_memory_output_type_check) { - if (partition_device && device_type != DEVICE_CPU) { - MemoryType original_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, - &original_output_memory_type) - .ok()) { - return false; - } - MemoryType const_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, - &const_output_memory_type) - .ok()) { - return false; - } - if (original_output_memory_type != const_output_memory_type) { - return false; - } - } - } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); @@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant( graph, partition_device, tensors_to_replace[c], outputs[c], - control_deps, opts.max_constant_size_in_bytes, - opts.disable_memory_output_type_check, generate_new_name)) { + control_deps, opts.max_constant_size_in_bytes, generate_new_name)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index 4c71b7bd27..a9a84f761b 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -45,10 +45,6 @@ struct ConstantFoldingOptions { // optimization. int64 max_constant_size_in_bytes = 10 * 1024 * 1024; - // If disable_memory_output_type_check is true, we will disable output memory - // type check for constant node replacement. - bool disable_memory_output_type_check = false; - // A generator for the name suffix of constant folded nodes. A // default id generator that monotonically increases is used if nullptr is // passed. diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 2c48084cab..40ec1502da 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -54,6 +54,7 @@ limitations under the License.
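The numbered rules above collapse into a small predicate. A condensed sketch, with an illustrative function name and precomputed arguments (this is not the actual ReplaceTensorWithConstant signature):

// Returns true iff the tensor may be replaced by a Const node, per rules 1-5.
bool MayReplaceWithConstant(bool is_already_constant, MemoryType memory_type,
                            bool is_int32, int64 size_in_bytes,
                            int64 max_constant_size_in_bytes,
                            bool has_kernel_for_device) {
  if (is_already_constant) return false;                         // rule 1
  if (memory_type == HOST_MEMORY && !is_int32) return false;     // rule 2
  if (memory_type == DEVICE_MEMORY && is_int32) return false;    // rule 3
  if (size_in_bytes > max_constant_size_in_bytes) return false;  // rule 4
  return has_kernel_for_device;                                  // rule 5
}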
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -1240,6 +1241,7 @@ class ExecutorState { StepStatsCollectorInterface* const stats_collector_; const tracing::TraceCollector* const trace_collector_; const tracing::EventCollector* const event_collector_; + Context context_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). @@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) trace_collector_(tracing::GetTraceCollector()), event_collector_( tracing::GetEventCollector(tracing::EventCategory::kCompute)), + context_(ContextKind::kThread), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), @@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item, } void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { + WithContext wc(context_); const GraphView& gview = impl_->gview_; TaggedNodeSeq ready; TaggedNodeReadyQueue inline_ready; diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 91194bc86f..37a979a8f1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -39,8 +39,7 @@ void GraphOptimizer::Optimize( const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn, - const std::function<bool(const Node*)>& cf_consider_fn, - bool cf_disable_memory_output_type_check) { + const std::function<bool(const Node*)>& cf_consider_fn) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -65,8 +64,6 @@ void GraphOptimizer::Optimize( ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; cf_opts.consider = cf_consider_fn; - cf_opts.disable_memory_output_type_check = - cf_disable_memory_output_type_check; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 8954e9612d..789cc56942 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -47,16 +47,13 @@ class GraphOptimizer { // returns true will be considered for CSE. // If cf_consider_fn is not null then only nodes for which cf_consider_fn // returns true will be considered for CF. - // If cf_disable_memory_output_type_check is true, CF will discard output - // memory type check for constant node replacement. 
void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn = nullptr, - const std::function<bool(const Node*)>& cf_consider_fn = nullptr, - bool cf_disable_memory_output_type_check = false); + const std::function<bool(const Node*)>& cf_consider_fn = nullptr); const OptimizerOptions& options() { return opts_; } diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a02084f223..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index fa4d1eda62..9488a44778 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, "output_port '", output_port, "' is out of range, ", "node '", node->name(), "' has ", node->num_outputs(), " outputs"); } + // Note: it's possible, if the node's been updated, that the shape inference + // context doesn't have the right number of outputs. 
+ if (node->num_outputs() > c->num_outputs()) { + TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs())); + } // Check compatibility, and merge the shapes. ShapeHandle existing_shape = c->output(output_port); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 3e77028a5f..4dcc80680f 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -239,6 +239,15 @@ void InferenceContext::PreInputInit( output_handle_shapes_and_types_.resize(num_outputs); } +Status InferenceContext::ExpandOutputs(int new_output_size) { + if (new_output_size < outputs_.size()) { + return errors::InvalidArgument("Trying to reduce number of outputs of op."); + } + outputs_.resize(new_output_size, nullptr); + output_handle_shapes_and_types_.resize(new_output_size); + return Status::OK(); +} + void InferenceContext::PostInputInit( std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) { int num_inputs_from_node_def = 0; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 81258b55b3..e3885b7d9e 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -323,13 +323,13 @@ class InferenceContext { return input_tensors_as_shapes_; } - ShapeHandle output(int64 idx) const { return outputs_[idx]; } - void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; } + ShapeHandle output(int64 idx) const { return outputs_.at(idx); } + void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; } Status set_output(StringPiece output_name, const std::vector<ShapeHandle>& shapes); int num_outputs() const { return outputs_.size(); } - ShapeHandle output(int idx) const { return outputs_[idx]; } + ShapeHandle output(int idx) const { return outputs_.at(idx); } Status output(StringPiece output_name, std::vector<ShapeHandle>* output) const; @@ -645,6 +645,9 @@ class InferenceContext { return merged_dims_; } + // Adds new outputs; useful when mutating the graph. + Status ExpandOutputs(int new_output_size); + private: // Creates and stores shapes for use in InferenceContext. class ShapeManager { diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 7a4a0096fa..6f068546d2 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -142,6 +142,19 @@ void Node::Clear() { assigned_device_name_index_ = 0; } +void Node::UpdateProperties() { + DataTypeVector inputs; + DataTypeVector outputs; + Status status = + InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs); + if (!status.ok()) { + LOG(ERROR) << "Failed at updating node: " << status; + return; + } + props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def, + inputs, outputs); +} + const string& Node::name() const { return props_->node_def.name(); } const string& Node::type_string() const { return props_->node_def.op(); } const NodeDef& Node::def() const { return props_->node_def; } diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 2944951f82..228b1331d9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -171,6 +171,7 @@ class Node { template <typename T> void AddAttr(const string& name, const T& val) { SetAttrValue(val, AddAttrHelper(name)); + UpdateProperties(); } void ClearAttr(const string& name); @@ -211,6 +212,10 @@ class Node { // e.g. in AddAttr. 
void MaybeCopyOnWrite(); + // Called after an attr has changed. Decides whether we need to update some + // property of the node (stored in props_). + void UpdateProperties(); + AttrValue* AddAttrHelper(const string& name); // A set of mutually exclusive classes for different kinds of nodes, diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index d92874909f..68a20fcc5f 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) { strings::StrCat("Attempt to add nullptr Node to node with type ", def_builder_.op_def().name())); } else { - errors_.emplace_back( - strings::StrCat("Attempt to add output ", i, " of ", node->name(), - " not in range [0, ", node->num_outputs(), - ") to node with type ", def_builder_.op_def().name())); + errors_.emplace_back(strings::StrCat( + "Attempt to add output ", i, " of ", node->name(), " not in range [0, ", + node->num_outputs(), ") to node with type ", + def_builder_.op_def().name(), ". Node: ", node->DebugString())); } } diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 1b5a215987..cbf5c8e038 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) { } bool IsControlFlow(const NodeDef& node) { - // clang-format off - return node.op() == "ControlTrigger" || - node.op() == "Enter" || - node.op() == "Exit" || - node.op() == "LoopCond" || - node.op() == "Merge" || - node.op() == "NextIteration" || - node.op() == "Switch"; - // clang-format on + // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative + // string comparison. + static const gtl::FlatSet<string>* const kControlFlowOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ + "ControlTrigger", + "Enter", + "Exit", + "LoopCond", + "Merge", + "NextIteration", + "Switch", + })); + return kControlFlowOps->count(node.op()) > 0; } bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 37aa24b947..985d6c6c3a 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -13,9 +13,19 @@ VECTORIZER_DEPS = [ ] + tf_protos_all() cc_library( + name = "wrapped_tensor", + hdrs = ["wrapped_tensor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "vectorizer", hdrs = ["vectorizer.h"], deps = [ + ":wrapped_tensor", "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ] + tf_protos_all(), diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index 3af6bab409..f445157531 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class CastVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); @@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer { auto new_cast_node = outer_scope->AddNode(node.def(), &s); TF_RETURN_IF_ERROR(s); - // Add input and output mappings - input_ports->push_back({new_cast_node, 0}); - output_ports->push_back({new_cast_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, + 0); + + // Add output mappings + outputs->push_back({new_cast_node, 0, true}); return Status::OK(); } }; REGISTER_VECTORIZER("Cast", CastVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 74ce520ce1..f1ba741821 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -19,15 +19,15 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class UnpackVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; - if (node.num_inputs() != 1) { + if (node.num_inputs() != 1 || inputs.size() != 1) { return errors::Internal("Unpack op should only have one input."); } @@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer { int new_axis = node.def().attr().at("axis").i() + 1; new_unpack_node->AddAttr("axis", new_axis); - // Add the input mappings - input_ports->push_back({new_unpack_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, + new_unpack_node, 0); // Add the output mappings int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - output_ports->push_back({new_unpack_node, i}); + outputs->push_back({new_unpack_node, i, true}); } return Status::OK(); @@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer { REGISTER_VECTORIZER("Unpack", UnpackVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index 56eb88c95e..8d4676aae0 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -18,15 +18,12 @@ limitations under the License. 
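The `axis + 1` in UnpackVectorizer above follows from the stacking convention: a vectorized input carries an extra leading batch dimension, so an unpack that applied along axis `a` per slice must apply along `a + 1` on the stacked tensor. A worked shape example:

// Per-slice input:  shape [4, 5], "axis" = 0  -> 4 outputs of shape [5]
// Stacked input:    shape [batch, 4, 5]
// Unpacking at axis 0 would split the batch dimension itself; unpacking at
// axis 1 (= 0 + 1) yields 4 stacked outputs of shape [batch, 5], as intended.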
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { - -// Describes a tensor with its operation Node and output position -typedef std::pair<Node*, int> Port; // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. @@ -36,17 +33,17 @@ class Vectorizer { // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs. The new Node(s) collectively have the + // on elements of `inputs`. The new Node(s) collectively have the // same number of input and output ports as the node being converted. - // Adds mappings for the new nodes' input and output ports to `inputs` and - // `outputs` respectively, where the i'th Port in inputs/outputs - // corresponds to the i'th input/output port of the node to be converted. + // Adds edges between the newly created nodes and nodes in `inputs`, and adds + // mappings to the new nodes' output ports to `outputs`, where the i'th + // value in `outputs` corresponds to the i'th output port of the node + // to be converted. virtual Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) = 0; + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) = 0; }; -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc index a6551e36ac..e1cf77a7d5 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc @@ -19,7 +19,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { VectorizerRegistry* VectorizerRegistry::Global() { static VectorizerRegistry* registry = new VectorizerRegistry; @@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type, vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>( op_type, std::move(vectorizer))); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h index 16159d47ca..ad54c74933 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h @@ -23,7 +23,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { // A global VectorizerRegistry is used to hold all the vectorizers. 
class VectorizerRegistry { @@ -59,16 +58,12 @@ class VectorizerRegistration { #define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) -#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ - static ::tensorflow::grappler::vectorization_utils:: \ - vectorizer_registration::VectorizerRegistration \ - vectorizer_registration_##ctr( \ - op_type, \ - ::std::unique_ptr< \ - ::tensorflow::grappler::vectorization_utils::Vectorizer>( \ - new vectorizer())) +#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ + static ::tensorflow::grappler::vectorizer_registration:: \ + VectorizerRegistration vectorizer_registration_##ctr( \ + op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \ + new vectorizer())) -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc index 663ceba027..054aeb9a8f 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -20,13 +20,12 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { class TestVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* inputs, - std::vector<Port>* outputs) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { return Status::OK(); } }; @@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) { NodeDef node_def; Status s; Node* node = g.AddNode(node_def, &s); - std::vector<Port> inputs, outputs; - EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok()); + std::vector<WrappedTensor> inputs, outputs; + EXPECT_TRUE( + vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok()); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h new file mode 100644 index 0000000000..4439b4ab4e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace grappler { + +// Represents a tensor that has been vectorized. 
+struct WrappedTensor { + Node* const node; + const int output_index; + + // Whether the tensor is stacked, i.e. represents the results of applying + // the operation on all slices of the input, where each row i of the + // tensor corresponds to the op's output on slice i of the input. False + // if the tensor is not stacked, i.e. represents the result of the op on + // a single slice of the input, where the result does not vary between + // slices. + bool stacked; + + WrappedTensor(Node* node, int output_index, bool stacked) + : node(node), output_index(output_index), stacked(stacked) {} +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 344c420902..ba857ab5d9 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -45,22 +45,6 @@ namespace { // Describes a tensor with its operation Node and output position typedef std::pair<Node*, int> TensorDesc; -// Equivalent to python Pfor's WrappedTensor struct -struct WrappedTensor { - TensorDesc tensor; - - // Whether the tensor is stacked, i.e. represents the results of applying - // the operation on all slices of the input, where each row i of the - // tensor corresponds to the op's output on slice i of the input. False - // if the tensor is not stacked, i.e. represents the result of the op on - // a single slice of the input, where the result does not vary between - // slices. - bool stacked; - - WrappedTensor(TensorDesc&& tensor, bool stacked) - : tensor(std::move(tensor)), stacked(stacked) {} -}; - const char* const kRetValOp = "_Retval"; void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, @@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) { return errors::Unimplemented("No vectorizer registered for op: ", op_node->type_string()); } - std::vector<Port> input_ports, output_ports; - input_ports.reserve(op_node->num_inputs()); - output_ports.reserve(op_node->num_outputs()); - TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), - &input_ports, &output_ports)); + std::vector<WrappedTensor> inputs, outputs; + inputs.reserve(op_node->num_inputs()); + outputs.reserve(op_node->num_outputs()); std::vector<const Edge*> input_edges; TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); - if (op_node->num_outputs() != output_ports.size() || - op_node->num_inputs() != input_ports.size() || - input_edges.size() != input_ports.size()) { - return errors::Internal("Vectorizer inputs/outputs don't match."); - } - - // Promote the inputs of the op to MapDefun outputs and connect the edges - // accordingly. + // The inputs for the node to be converted may already have been converted + // themselves. For those that are not, we promote them to MapDefun outputs. 
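For orientation, a minimal vectorizer against the new interface could look like the sketch below. The Identity vectorizer is hypothetical (this change only adds Cast and Unpack); it shows the contract: consume `inputs` by adding edges into the new node, and report stacked outputs:

namespace {
class IdentityVectorizer : public Vectorizer {
 public:
  Status Vectorize(const Node& node, Graph* outer_scope,
                   std::vector<WrappedTensor>&& inputs,
                   std::vector<WrappedTensor>* outputs) override {
    if (inputs.size() != 1) {
      return errors::Internal("Identity op should only have one input.");
    }
    Status s;
    Node* new_node = outer_scope->AddNode(node.def(), &s);
    TF_RETURN_IF_ERROR(s);
    // Wire the (possibly already-converted) input into the new node.
    outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_node, 0);
    // Output is stacked: row i is the op applied to slice i of the input.
    outputs->push_back({new_node, 0, true});
    return Status::OK();
  }
};
// REGISTER_VECTORIZER("Identity", IdentityVectorizer);  // illustrative only
}  // namespace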
for (size_t i = 0; i < op_node->num_inputs(); ++i) { auto edge = input_edges[i]; - TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, - {edge->src(), edge->src_output()})); - outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1, - input_ports[i].first, input_ports[i].second); + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + inputs.push_back(*found); + } else { + // TODO(rachelim): Handle the case where unconverted inputs are unstacked. + // We assume that all unconverted inputs will be stacked, since we + // converted all unstacked nodes in `Initialize`. However, it's actually + // possible that yet-unconverted nodes may produce unstacked outputs after + // they are vectorized. (For example, see the "Shape" converter in + // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects + // an unstacked input but receives a stacked one, vectorizer->Vectorize + // will return an error. + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + int output_index = map_defun_fn_->ret_nodes.size() - 1; + inputs.push_back({map_defun_node_, output_index, true}); + } + } + + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs)); + + if (op_node->num_outputs() != outputs.size()) { + return errors::Internal( + "Number of vectorizer outputs does not match. Expected: ", + op_node->num_outputs(), " Actual: ", outputs.size()); } // Add output mappings. for (size_t i = 0; i < op_node->num_outputs(); ++i) { - conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}}); + conversion_map_.insert({{op_node, i}, outputs[i]}); } return Status::OK(); @@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) { TensorDesc output({ret_edge->src(), ret_edge->src_output()}); TensorDesc converted_output; - if (auto found = gtl::FindOrNull(conversion_map_, output)) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - if (found->stacked) { - converted_output = found->tensor; - } else { - // Some outputs may be unstacked if they don't derive from arg nodes - // (for example, if a function returns a constant). For these, we - // have to add extra nodes to tile it in the 0th dimension. - TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); - } - } else { - // Note: All unstacked nodes are converted ahead of time in `Initialize`, - // and here we assume that all op vectorizers create only stacked outputs. - // This may not hold in the future, as more vectorizers are added that - // may actually create unstacked outputs. For example, see the `Shape` - // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py + + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + auto found = gtl::FindOrNull(conversion_map_, output); + if (!found) { TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); - converted_output = conversion_map_.at(output).tensor; + found = &conversion_map_.at(output); + } + + if (found->stacked) { + converted_output = {found->node, found->output_index}; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). For these, we + // have to add extra nodes to tile it in the 0th dimension. 
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); } ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, @@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, Node* ones_shape; TF_RETURN_IF_ERROR(node_builder("Shape") - .Input(unstacked->tensor.first) // input + .Input(unstacked->node) // input .Finalize(g, &ones_shape)); Node* ones; @@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, Node* expand_dims; TF_RETURN_IF_ERROR(node_builder("ExpandDims") - .Input(unstacked->tensor.first) // input - .Input(const_0) // dim + .Input(unstacked->node) // input + .Input(const_0) // dim .Finalize(g, &expand_dims)); TF_RETURN_IF_ERROR(node_builder("Tile") @@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() { TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}}); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); // Control inputs conversion_map_.insert({{arg_node, Graph::kControlSlot}, - {{input_node, Graph::kControlSlot}, true}}); + {input_node, Graph::kControlSlot, true}}); } return Status::OK(); } @@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, if (auto found = gtl::FindOrNull(conversion_map_, {edge->src(), edge->src_output()})) { - outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node, + outer_scope_->AddEdge(found->node, found->output_index, node, edge->dst_input()); } else { status->Update(errors::Internal( @@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, // Add output mappings for (int i = 0; i < tensor.first->num_outputs(); ++i) { - conversion_map_.insert( - {{tensor.first, i}, WrappedTensor({node, i}, false)}); + conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)}); } conversion_map_.insert({{tensor.first, Graph::kControlSlot}, - WrappedTensor({node, Graph::kControlSlot}, false)}); + WrappedTensor(node, Graph::kControlSlot, false)}); return true; } diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h index 765dd13263..bd6bf9f860 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ +#include <atomic> #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { namespace grappler { @@ -29,6 +32,7 @@ struct GrapplerItem; // optimization of a GrapplerItem for running on a cluster. class GraphOptimizer { public: + GraphOptimizer() : is_cancelled_(false) {} virtual ~GraphOptimizer() {} virtual string name() const = 0; @@ -45,8 +49,25 @@ class GraphOptimizer { // call to Optimize) performed. Lower "result" scores are better. virtual void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& optimized_graph, double result) = 0; + + // Best effort cancellation. Sets is_cancelled to true and requests that the + // optimizer return as soon as possible from active calls to Optimize() or + // Feedback().
+ void Cancel() { is_cancelled_ = true; } + + bool is_cancelled() const { return is_cancelled_; } + + private: + std::atomic<bool> is_cancelled_; }; +#define GRAPPLER_RETURN_IF_CANCELLED() \ + do { \ + if (is_cancelled()) { \ + return errors::DeadlineExceeded(this->name(), " was cancelled."); \ + } \ + } while (0) + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c3d70a1fdf..7488cedec5 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" + +#include <memory> + #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -37,7 +40,11 @@ limitations under the License. #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -107,13 +114,29 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( MK_OPT("scoped_allocator", new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); - MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); + MK_OPT("pin_to_host", + new PinToHostOptimizer(cfg_.pin_to_host_optimization())); return std::unique_ptr<GraphOptimizer>(); } #undef MK_OPT +MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg) + : cpu_device_(cpu_device), cfg_(cfg) { + // TODO(rmlarsen): Increase kNumThreads to, say, port::NumSchedulableCPUs() + // if we want to use the threadpool for parallelizing Grappler. + const int kNumThreads = 1; + thread_pool_ = absl::make_unique<thread::ThreadPool>( + Env::Default(), "MetaOptimizerThreadPool", kNumThreads); +} + +MetaOptimizer::~MetaOptimizer() { + // The ThreadPool destructor waits for threads to finish, so we don't + // pull the rug out from under them. + thread_pool_.reset(); +} + Status MetaOptimizer::InitializeOptimizers( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { if (cfg_.disable_meta_optimizer()) { @@ -139,7 +162,7 @@ Status MetaOptimizer::InitializeOptimizers( if (cfg_.remapping() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping())); } - if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) { + if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique<PinToHostOptimizer>()); } if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { @@ -309,6 +332,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, VLOG(4) << "Starting optimization iteration " << iteration; for (const auto& optimizer : optimizers) { + GRAPPLER_RETURN_IF_CANCELLED(); // Some optimizers can run only once. if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue; // Some must run only on the last iteration.
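A sketch of how a long-running optimizer would cooperate with this best-effort cancellation; the subclass, pass count, and loop body here are hypothetical:

Status MyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                             GraphDef* optimized_graph) {
  for (int pass = 0; pass < kNumPasses; ++pass) {
    // Returns DeadlineExceeded if Cancel() ran on another thread.
    GRAPPLER_RETURN_IF_CANCELLED();
    // ... one rewrite pass over item.graph ...
  }
  return Status::OK();
}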
@@ -367,6 +391,7 @@ Status MetaOptimizer::RunOptimizer( // resets optimized_graph to an empty graph. optimized_graph->Swap(&optimized_item->graph); *optimized_graph = GraphDef(); + // TODO(rmlarsen): Add timeout for individual optimizers. Status status = optimizer->Optimize(cluster, *optimized_item, optimized_graph); uint64 end_us = Env::Default()->NowMicros(); @@ -388,14 +413,15 @@ Status MetaOptimizer::RunOptimizer( return status; } -Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { +Status MetaOptimizer::OptimizeMainGraphAndFunctionLibrary( + Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { VLOG(1) << "Starting optimization for grappler item: " << item.id; optimization_results_.clear(); // 1. Optimize main graph TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); VLOG(1) << "Optimized main graph."; + GRAPPLER_RETURN_IF_CANCELLED(); // Skip optimizing functions if this is a TPU graph. Currently, Grappler // passes do not handle TPU functions correctly in a variety of ways (Note @@ -431,6 +457,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimize_function_library = false; for (const FunctionDef& func : optimized_graph->library().function()) { + GRAPPLER_RETURN_IF_CANCELLED(); + const string& func_name = func.signature().name(); // Skip already optimized functions. @@ -505,6 +533,43 @@ void MetaOptimizer::PrintResult() { } } +Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + const int64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000; + const int64 timeout_usec = (cfg_.meta_optimizer_timeout_ms() == 0 + ? kFiveMinutesInUsec + : cfg_.meta_optimizer_timeout_ms() * 1000); + if (timeout_usec < 0) { + return OptimizeMainGraphAndFunctionLibrary(cluster, item, optimized_graph); + } + + GraphDef optimized_with_timeout; + Status status; + Notification done; + thread_pool_->Schedule( + [this, cluster, &done, &optimized_with_timeout, &item, &status]() { + status = this->OptimizeMainGraphAndFunctionLibrary( + cluster, item, &optimized_with_timeout); + done.Notify(); + }); + + const bool notified = WaitForNotificationWithTimeout(&done, timeout_usec); + if (notified && status.ok()) { + optimized_graph->Swap(&optimized_with_timeout); + } else { + *optimized_graph = item.graph; + if (!notified) { + this->Cancel(); + done.WaitForNotification(); + status = errors::DeadlineExceeded( + "Grappler MetaOptimizer timed out after ", + static_cast<float>(timeout_usec) / (1000 * 1000), " seconds"); + LOG(WARNING) << status.error_message(); + } + } + return status; +} + void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& pruned_graph, double result) { // Nothing to do for MetaOptimizer. 
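The timeout logic in MetaOptimizer::Optimize above is a reusable pattern: schedule the work on a thread pool, wait with a deadline, and on expiry cancel and then block until the worker actually exits before touching its output. Condensed, with DoWork() standing in for OptimizeMainGraphAndFunctionLibrary():

Status status;
Notification done;
thread_pool_->Schedule([&status, &done]() {
  status = DoWork();  // runs on the pool thread, writes `status`
  done.Notify();
});
if (!WaitForNotificationWithTimeout(&done, timeout_usec)) {
  Cancel();                    // best effort: ask the worker to stop early
  done.WaitForNotification();  // ensure `status` isn't written after we return
  status = errors::DeadlineExceeded("timed out");
}
return status;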
@@ -527,7 +592,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || cfg.debug_stripper() == RewriterConfig::ON || cfg.scoped_allocator_optimization() == RewriterConfig::ON || - cfg.pin_to_host_optimization() == RewriterConfig::ON || + cfg.pin_to_host_optimization() != RewriterConfig::OFF || !cfg.optimizers().empty() || !cfg.custom_optimizers().empty(); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 99a0a33ffa..35d6a4559b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -28,9 +29,8 @@ namespace grappler { // Run the other grappler optimizers based on the specified rewriter config. class MetaOptimizer : public GraphOptimizer { public: - MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg) - : cpu_device_(cpu_device), cfg_(cfg) {} - ~MetaOptimizer() override = default; + MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg); + ~MetaOptimizer(); string name() const override { return "meta_optimizer"; }; @@ -65,9 +65,18 @@ class MetaOptimizer : public GraphOptimizer { Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph); + // Run optimization passes over the main graph and for functions in the + // function library. + Status OptimizeMainGraphAndFunctionLibrary(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph); + DeviceBase* const cpu_device_; // may be NULL RewriterConfig cfg_; + // Thread pool used for launching optimizers asynchronously. 
+ std::unique_ptr<thread::ThreadPool> thread_pool_; + struct OptimizerResult { string optimizer_name; string result; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 3f3f43382f..7f1dd91f09 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -461,6 +461,68 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) { EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites); } +class SleepingOptimizer : public CustomGraphOptimizer { + public: + SleepingOptimizer() {} + string name() const override { return "test_optimizer"; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + *optimized_graph = item.graph; + optimized_graph->add_node(); + sleep(1); + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} +}; + +REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer); + +TEST_F(MetaOptimizerTest, OptimizerTimesOut) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("SleepingOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + rewriter_config.set_meta_optimizer_timeout_ms(1500); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + EXPECT_EQ(status.error_message(), + "Grappler MetaOptimizer timed out after 1.5 seconds"); + // Make sure the graph was reverted to the original regardless of when the + // optimizer timed out. + CompareGraphs(item.graph, output); +} + +TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("SleepingOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + rewriter_config.set_meta_optimizer_timeout_ms(1500); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE); + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_EQ(item.graph.node_size() + 1, output.node_size()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 8ed4271fa4..29a3b2b74c 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -25,16 +25,29 @@ limitations under the License. 
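On the timing in the two tests above: SleepingOptimizer sleeps one second per Optimize() call, so the expected outcomes follow from simple arithmetic:

// OptimizerTimesOut:       TWO iterations x ~1 s ≈ 2 s  > 1500 ms -> timeout;
//                          the output graph reverts to item.graph.
// OptimizerDoesNotTimeOut: ONE iteration  x ~1 s ≈ 1 s  < 1500 ms -> success;
//                          the output gains exactly the one added node.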
#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace grappler { + namespace internal { +namespace { // TODO(williamchan): Change this constant to be something smarter, maybe // dynamically determined. constexpr int64 kTensorMaxSize = 64; +struct OpDevicePortHasher { + std::size_t operator()(const std::tuple<string, string, int>& x) const { + uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x))); + + return Hash64Combine(code, hash<int>()(std::get<2>(x))); + } +}; +using OpDevicePortOnHostMap = + gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>; + // All the nodes that should be blacklisted and not swapped. bool IsBlacklisted(const NodeDef& node) { return @@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices, // Checks if a node's output port is host friendly. // Roughly this means checking if the output port is on Host memory. -Status IsNodeOutputPortHostFriendly(const GraphView& graph, - GraphProperties* properties, - const NodeDef& node, int port_id, - bool* is_candidate) { +Status IsNodeOutputPortHostFriendly( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { *is_candidate = false; // Make sure we are not a blacklisted op. @@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, for (const auto& fanin : graph.GetFanins(node, false)) { bool fanin_candidate = false; TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( - graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); if (!fanin_candidate) { return Status::OK(); } @@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, return Status::OK(); } + // Check `op_device_outport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + const std::tuple<string, string, int> cache_key(node.op(), node.device(), + port_id); + auto it = op_device_outport_pinned_to_host_cache->find(cache_key); + if (it != op_device_outport_pinned_to_host_cache->end()) { + *is_candidate = it->second; + return Status::OK(); + } + // Check if op's output port is pinned to HostMemory. 
const OpDef* op = nullptr; Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); if (!s.ok()) { LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, LOG(WARNING) << "Invalid port: " << port_id << "!\n" << node.DebugString() << "\n" << op->DebugString(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, &kernel); if (!s.ok()) { LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, } } + op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate); + return Status::OK(); } // Checks if a node's input port is Host friendly. // Roughly this means checking if the input port is on Host memory. -bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { +bool IsNodeInputPortHostFriendly( + const NodeDef& node, int port_id, + OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) { // If node is on Host, assume its inputs are Host friendly. if (str_util::StrContains(node.device(), DEVICE_CPU)) { return true; } + // Check `op_device_inport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id); + auto it = op_device_inport_pinned_to_host_cache->find(cache_key); + if (it != op_device_inport_pinned_to_host_cache->end()) { + return it->second; + } + // Check if op's input port is pinned to HostMemory. const OpDef* op = nullptr; Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); if (!s.ok()) { LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id); @@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel); if (!s.ok()) { LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } // Check if the input_arg is pinned to Host. for (const string& host_memory_arg : kernel->host_memory_arg()) { if (op->input_arg(input_arg_id).name() == host_memory_arg) { + op_device_inport_pinned_to_host_cache->emplace(cache_key, true); return true; } } + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + return false; } @@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { // 2] Check if node can run on Host. // 3] Check all input/outputs are Host "friendly" (atm, friendly means small, // ints, and pinned to Host). -Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, - const NodeDef& node, bool* is_candidate) { +Status IsNodeHostCandidate( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { *is_candidate = false; - // Check if node already on CPU. - if (str_util::StrContains(node.device(), DEVICE_CPU)) { - *is_candidate = true; + // Skip these node types. 
+ if (IsBlacklisted(node)) { return Status::OK(); } - // Skip these node types. - if (IsBlacklisted(node)) { + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; return Status::OK(); } @@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, return Status::OK(); } - // Check all inputs are Host friendly. - for (const GraphView::OutputPort& fanin : - graph.GetFanins(node, /*include_controlling_nodes=*/false)) { - bool fanin_candidate = false; - TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( - graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); - if (!fanin_candidate) { - return Status::OK(); - } - } - // Check all outputs are Host friendly. if (!properties->has_properties()) { // This is an expensive call, call it lazily. @@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, } } + // Check all inputs are Host friendly. + for (const GraphView::OutputPort& fanin : + graph.GetFanins(node, /*include_controlling_nodes=*/false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); + } + } + *is_candidate = true; return Status::OK(); } -string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device) { +bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || + node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} +} // end namespace + +// Tries to swap `device` to a Host device from `devices`. Returns true iff +// there was a swap. +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device) { // Force this node onto the CPU. - if (device.empty() && has_device_cpu) { - return "/device:CPU:0"; - } else if (str_util::StrContains(device, DEVICE_GPU)) { + if (device->empty() && has_device_cpu) { + *device = "/device:CPU:0"; + return true; + } else if (str_util::StrContains(*device, DEVICE_GPU)) { // Sometimes the cluster can have: // devices = {"/device:CPU:0", "/device:XLA_GPU:0"} // and we need to handle them properly. @@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices, {std::pair<string, string>("GPU", "CPU:0"), std::pair<string, string>("/device", "/device:CPU:0")}) { const string device_host = - strings::StrCat(device.substr(0, device.rfind(device_match.first)), + strings::StrCat(device->substr(0, device->rfind(device_match.first)), device_match.second); if (devices.find(device_host) != devices.end()) { - return device_host; + *device = device_host; + return true; } } } - // We couldn't find an appropriate Host device, return original device. - return device; -} - -bool IsTPUGraphDef(const GraphDef& def) { - for (const auto& node : def.node()) { - if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || - node.op() == "TPUPartitionedCall") { - return true; - } - } + // We couldn't find an appropriate Host device, return false. 
return false; } + } // end namespace internal Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // All the Const nodes, and their original devices in topological order. std::vector<std::pair<NodeDef*, string>> const_nodes; + // Cache to map {op, device, port} -> bool on whether it is pinned to host. + internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache; + internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache; + for (auto& node : *optimized_graph->mutable_node()) { bool is_candidate = false; - TF_RETURN_IF_ERROR( - internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate)); + TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate( + graph, &properties, node, &op_device_outport_pinned_to_host_cache, + &is_candidate)); if (!is_candidate) { continue; } - if (IsConstant(node)) { - const_nodes.emplace_back(&node, node.device()); + const string original_device = node.device(); + const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu, + node.mutable_device()); + // Keep track of all Const nodes that we swapped. + if (swapped && IsConstant(node)) { + const_nodes.emplace_back(&node, original_device); } - // Try and swap the device to Host. - node.set_device( - internal::TryFindHostDevice(devices, has_device_cpu, node.device())); } // Traverse all `const_nodes`, and map them back to GPU greedily. @@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // this node back onto the original device. for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { // The consumer is not Host friendly, swap it back to the original device. - if (!internal::IsNodeInputPortHostFriendly(*fanout.node, - fanout.port_id)) { + if (!internal::IsNodeInputPortHostFriendly( + *fanout.node, fanout.port_id, + &op_device_inport_pinned_to_host_cache)) { node->set_device(device); break; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h index d557a03463..bed4a9ef95 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -26,8 +26,8 @@ namespace tensorflow { namespace grappler { namespace internal { // Try and find an appropriate Host device in `devices` given `device`. 
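// The new form reports via its return value whether `device` was mutated.
// A usage sketch (the device strings here are illustrative):
//
//   gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:GPU:0"};
//   string device = "/device:GPU:0";
//   if (internal::TrySwapToHostDevice(devices, /*has_device_cpu=*/true,
//                                     &device)) {
//     // `device` now holds "/device:CPU:0".
//   } else {
//     // No suitable Host device was found; `device` is left untouched.
//   }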
-string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device); +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device); } // end namespace internal // Optimize TensorFlow ops that should be swapped into the CPU to avoid diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 7c64529441..9bb030b220 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -28,30 +28,60 @@ namespace { class PinToHostOptimizerTest : public GrapplerTest {}; -TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) { gtl::FlatSet<string> devices = {}; - EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); - - devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), - "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), - "/device:CPU:0"); - - devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_CPU:0"); - - devices = {"/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_GPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_GPU:*"); + + string device = "ABC"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "ABC"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, 
"/device:XLA_GPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:*"); } TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 451f8c1a6c..37c1c54786 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -45,6 +45,16 @@ cc_library( ], ) +tf_cc_test( + name = "dataset_utils_test", + srcs = ["dataset_utils_test.cc"], + deps = [ + ":dataset_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "captured_function", srcs = ["captured_function.cc"], @@ -205,6 +215,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -232,6 +243,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -245,6 +257,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -285,6 +298,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", ":parallel_map_iterator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index e10833f525..a40f7f2146 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,10 +15,57 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { namespace data { +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices) { + FunctionLibraryRuntime::Handle fn_handle; + TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( + func.name(), AttrSlice(&func.attr()), &fn_handle)); + auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { + Status s = ctx->function_library()->ReleaseHandle(fn_handle); + if (!s.ok()) { + LOG(WARNING) << "Failed to release handle: " << s.error_message(); + } + }); + + const FunctionBody* fn_body = + ctx->function_library()->GetFunctionBody(fn_handle); + indices->resize(fn_body->ret_nodes.size()); + for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { + Node* ret_node = fn_body->ret_nodes[i]; + Node* ret_input_node; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); + if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { + TF_RETURN_IF_ERROR( + GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); + } else { + indices->clear(); + break; + } + } + return Status::OK(); +} + +std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) { + std::map<int, int> last_use; + for (size_t i = 0; i < indices.size(); ++i) { + last_use[indices[i]] = i; + } + std::vector<bool> can_move; + can_move.resize(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + can_move[i] = last_use[indices[i]] == i; + } + return can_move; +} + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6ec1350cd4..d777062293 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -22,6 +22,26 @@ limitations under the License. namespace tensorflow { namespace data { +// This method is used to determine whether we can short-circuit the evaluation +// of the user-defined function `func`. Short-circuiting is possible if every +// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = +// (y,x)`, or `f(x) = (x,x)`). +// +// If short-circuiting is possible, the method stores the mapping from output +// indices to input indices in `indices`. Otherwise, `indices` will be empty. +// +// Returns a non-OK status if analysis of the function fails. +// +// TODO(jsimsa): Extend this to support constants as well. +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices); + +// Given a vector that maps output indices to input indices, return a vector +// that identifies the output indices for which we can move the input (assuming +// output indices are processed left to right).
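// For example, mirroring dataset_utils_test.cc below: `indices` = {1, 1, 2}
// yields {false, true, true}, because input 1 feeds outputs 0 and 1 and only
// its last use may take the tensor by move; `indices` = {1, 2} yields
// {true, true}, since each input is used exactly once.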
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices); + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc new file mode 100644 index 0000000000..43295b8ebb --- /dev/null +++ b/tensorflow/core/kernels/data/dataset_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_utils.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +TEST(DatasetUtils, ComputeMoveVector) { + struct TestCase { + std::vector<int> indices; + std::vector<bool> expected; + }; + + TestCase test_cases[] = { + TestCase{{}, {}}, + TestCase{{1}, {true}}, + TestCase{{1, 1}, {false, true}}, + TestCase{{1, 2}, {true, true}}, + TestCase{{1, 1, 2}, {false, true, true}}, + TestCase{{1, 2, 2}, {true, false, true}}, + }; + + for (auto& test_case : test_cases) { + EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices)); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 00884314a9..be7d182a1f 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,9 +18,11 @@ limitations under the License. 
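
// The filter change below collapses the FilterTensorDataset /
// FilterFunctionDataset subclass pair into a single Dataset that receives a
// std::function predicate chosen once, at kernel-construction time. A reduced
// sketch of that selection, with Tensor flattened to int and the
// captured-function call stubbed out (names here are illustrative):

#include <functional>
#include <vector>

namespace sketch {

using Predicate = std::function<bool(const std::vector<int>&)>;

// Stand-in for CapturedFunction::RunWithBorrowedArgs plus the scalar-bool
// check performed on the slow path.
bool RunUserFunction(const std::vector<int>& args) { return !args.empty(); }

Predicate MakePredicate(const std::vector<int>& short_circuit_indices) {
  if (short_circuit_indices.empty()) {
    // Slow path: actually invoke the user-defined function per element.
    return [](const std::vector<int>& args) { return RunUserFunction(args); };
  }
  // Short-circuit path: the predicate just forwards one of its inputs, so no
  // function dispatch is needed at all.
  const int index = short_circuit_indices[0];
  return [index](const std::vector<int>& args) { return args[index] != 0; };
}

}  // namespace sketch
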
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -31,67 +33,84 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: + using FilterIteratorPredicate = + std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>; + explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - FunctionLibraryRuntime::Handle pred_handle; - OP_REQUIRES_OK(ctx, - ctx->function_library()->Instantiate( - func_.name(), AttrSlice(&func_.attr()), &pred_handle)); - auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { - OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); - }); - - const FunctionBody* pred_body = - ctx->function_library()->GetFunctionBody(pred_handle); - OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, - errors::InvalidArgument( - "predicate function must have a single return value.")); - Node* ret_node = pred_body->ret_nodes[0]; - Node* ret_input_node; - OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); - std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - if (ret_input_node->def().op() == "_Arg") { - int32 index = -1; - OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); - *output = new FilterTensorDataset(ctx, input, func_, - std::move(captured_func), index); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + OP_REQUIRES(ctx, indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); + + FilterIteratorPredicate filter_pred; + if (indices.empty()) { + CapturedFunction* raw_captured_func = captured_func.get(); + filter_pred = [raw_captured_func](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + std::vector<Tensor> result; + TF_RETURN_IF_ERROR( + raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar<bool>()(); + return Status::OK(); + }; } else { - *output = new FilterFunctionDataset(ctx, input, func_, - std::move(captured_func)); + filter_pred = [indices](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + const Tensor& predicate = args[indices[0]]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar<bool>()(); + return Status::OK(); + }; } + + *output = new Dataset(ctx, input, func_, std::move(captured_func), + std::move(filter_pred)); } private: - const int graph_def_version_; - - class FilterDatasetBase : public DatasetBase { + class Dataset : public 
DatasetBase { public: - FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func) + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, + FilterIteratorPredicate filter_pred) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + filter_pred_(std::move(filter_pred)) { input_->Ref(); } - ~FilterDatasetBase() override { input_->Unref(); } + ~Dataset() override { input_->Unref(); } std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Filter")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, + filter_pred_); } const DataTypeVector& output_dtypes() const override { @@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - virtual Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const = 0; - private: - class Iterator : public DatasetIterator<FilterDatasetBase> { + class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<FilterDatasetBase>(params), + explicit Iterator(const Params& params, + FilterIteratorPredicate filter_pred) + : DatasetIterator<Dataset>(params), filtered_elements_(0), - dropped_elements_(0) { + dropped_elements_(0), + filter_pred_(std::move(filter_pred)) { std::vector<string> components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR( - dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match. out_tensors->clear(); @@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); + const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; - - protected: const std::unique_ptr<CapturedFunction> captured_func_; - }; - - class FilterFunctionDataset : public FilterDatasetBase { - public: - using FilterDatasetBase::FilterDatasetBase; - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. 
- std::vector<Tensor> result; - TF_RETURN_IF_ERROR( - captured_func_->RunWithBorrowedArgs(ctx, element, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar<bool>()(); - return Status::OK(); - } - }; - - class FilterTensorDataset : public FilterDatasetBase { - public: - FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func, - int32 index) - : FilterDatasetBase(ctx, input, func, std::move(captured_func)), - index_(index) {} - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - const Tensor& predicate = element[index_]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar<bool>()(); - return Status::OK(); - } - - private: - const int32 index_; + const FilterIteratorPredicate filter_pred_; }; private: diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 7a833668ac..8acd6cc724 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -16,10 +16,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/renamed_device.h" -#include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" @@ -27,13 +25,11 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/public/session_options.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index bf08970560..0fb721cd7c 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -29,6 +30,7 @@ limitations under the License. 
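
// MapAndBatchDatasetOp below now threads a callback-style map function
// through its iterator instead of invoking the captured function directly.
// This sketch pins down the callback's shape; Status, Tensor, and the
// contexts are stubbed so the snippet stands alone (in TensorFlow,
// StatusCallback is std::function<void(const Status&)>).

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace sketch {

struct IteratorContext {};  // opaque stand-in
struct Status { static Status OK() { return Status(); } };
using Tensor = int;         // stand-in for tensorflow::Tensor
using StatusCallback = std::function<void(const Status&)>;

using MapAndBatchIteratorFunction = std::function<void(
    IteratorContext*, const std::string&, std::vector<Tensor>,
    std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;

// The simplest short-circuit flavor: move the arguments straight into the
// output buffer and signal completion. Compare the indices-aware lambda
// built in MakeDataset below.
MapAndBatchIteratorFunction Passthrough() {
  return [](IteratorContext*, const std::string&, std::vector<Tensor> args,
            std::shared_ptr<std::vector<Tensor>> out, StatusCallback done) {
    for (Tensor& t : args) out->push_back(std::move(t));
    done(Status::OK());
  };
}

}  // namespace sketch
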
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -41,6 +43,10 @@ namespace { // transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: + using MapAndBatchIteratorFunction = + std::function<void(IteratorContext*, const string&, std::vector<Tensor>, + std::shared_ptr<std::vector<Tensor>>, StatusCallback)>; + explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { @@ -91,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - *output = new Dataset(ctx, input, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, func_, - std::move(captured_func), &ctx->eigen_cpu_device()); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapAndBatchIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), + std::move(done), prefix); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + + *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, + std::move(captured_func), &ctx->eigen_cpu_device(), + std::move(map_func)); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, - const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, - const Eigen::ThreadPoolDevice* device) + const Eigen::ThreadPoolDevice* device, + MapAndBatchIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), + func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), - map_fn_(func), captured_func_(std::move(captured_func)), - device_(device) { + device_(device), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -123,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( 
const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, + map_func_); } const DataTypeVector& output_dtypes() const override { @@ -143,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -165,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(map_fn_, &f); + b->BuildAttrValue(func_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -185,12 +234,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) + explicit Iterator(const Params& params, + MapAndBatchIteratorFunction map_func) : DatasetIterator<Dataset>(params), mu_(std::make_shared<mutex>()), cond_var_(std::make_shared<condition_variable>()), num_parallel_calls_(std::make_shared<model::SharedState>( - params.dataset->num_parallel_calls_, mu_, cond_var_)) {} + params.dataset->num_parallel_calls_, mu_, cond_var_)), + map_func_(std::move(map_func)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -297,44 +348,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; - void Callback(const std::shared_ptr<IteratorContext>& ctx, - const std::shared_ptr<BatchResult>& result, - const std::shared_ptr<std::vector<Tensor>>& return_values, - int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does not " - "match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor batching. - Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - } - void CallCompleted(const std::shared_ptr<BatchResult>& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); @@ -363,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - // Call `captured_func_(input_element)`, using `Callback` to store the - // result in `result`. 
- (*ctx->runner())(std::bind( - [this, result, offset](std::shared_ptr<IteratorContext> ctx, - std::vector<Tensor> input_element) { - std::shared_ptr<std::vector<Tensor>> return_values( - new std::vector<Tensor>()); - dataset()->captured_func_->RunAsync( - ctx.get(), std::move(input_element), return_values.get(), - [this, ctx, result, return_values, offset](Status status) { - Callback(ctx, result, return_values, offset, status); - }, - prefix()); - }, - ctx, std::move(input_element))); + std::shared_ptr<std::vector<Tensor>> return_values = + std::make_shared<std::vector<Tensor>>(); + auto done = [this, ctx, result, return_values, offset](Status status) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor + // batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + }; + + // Apply the map function on `input_element`, storing the result in + // `return_values`, and invoking `done` when finished. 
+ map_func_(ctx.get(), prefix(), std::move(input_element), + std::move(return_values), std::move(done)); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -404,10 +444,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); - runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - std::bind(&Iterator::RunnerThread, this, ctx_copy))); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); + runner_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); + runner_thread_->Schedule( + std::bind(&Iterator::RunnerThread, this, ctx_copy)); } } @@ -509,8 +550,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.emplace_back( - new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -527,7 +568,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); std::shared_ptr<BatchResult> result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -653,6 +695,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr<model::SharedState> num_parallel_calls_; + const MapAndBatchIteratorFunction map_func_; + // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. @@ -660,7 +704,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> input_impl_; // Buffer for storing the (intermediate) batch results. std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); bool cancelled_ GUARDED_BY(*mu_) = false; }; @@ -671,9 +715,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; - const NameAttrList map_fn_; const std::unique_ptr<CapturedFunction> captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned + const MapAndBatchIteratorFunction map_func_; }; const int op_version_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index f112e1dc43..6b6ffabf4f 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,7 +17,9 @@ limitations under the License. 
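
// Unlike the asynchronous callbacks used by map_and_batch and parallel_map,
// the plain map kernel below threads a synchronous, Status-returning function
// type through its iterator. A sketch with stubbed types:

#include <functional>
#include <utility>
#include <vector>

namespace sketch {

struct IteratorContext {};  // opaque stand-in
struct Status { static Status OK() { return Status(); } };
using Tensor = int;         // stand-in for tensorflow::Tensor

using MapIteratorFunction = std::function<Status(
    IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;

// Identity-style short circuit: `f(x) = x` needs no function dispatch.
MapIteratorFunction Identity() {
  return [](IteratorContext*, std::vector<Tensor> args,
            std::vector<Tensor>* out) {
    *out = std::move(args);
    return Status::OK();
  };
}

}  // namespace sketch
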
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -28,6 +30,9 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: + using MapIteratorFunction = std::function<Status( + IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>; + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + return raw_captured_func->Run(ctx, std::move(args), out_tensors); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + return Status::OK(); + }; + } + *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_); + output_types_, output_shapes_, std::move(map_func)); } private: @@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) + const std::vector<PartialTensorShape>& output_shapes, + MapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes) { + output_shapes_(output_shapes), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Map")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); } const DataTypeVector& output_dtypes() const override { @@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + explicit Iterator(const Params& params, MapIteratorFunction map_func) + : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -139,10 
+180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - Status s = - dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); + Status s = map_func_(ctx, args, out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr<IteratorBase> input_impl_; + const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr<CapturedFunction> captured_func_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const MapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 9aa505f4f1..859df57962 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!optimize_thread_) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - optimize_thread_.reset(ctx->env()->StartThread( - {}, "optimize_thread", - [this, new_ctx]() { OptimizeThread(new_ctx); })); + optimize_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread"); + optimize_thread_->Schedule( + [this, new_ctx]() { OptimizeThread(new_ctx); }); } return Status::OK(); } @@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { mutex mu_; condition_variable cond_var_; std::shared_ptr<model::Model> model_; - std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_); + std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 6b6b3d6ab9..9c836b836e 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -26,6 +26,7 @@ limitations under the License. 
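
// The recurring edit in this file and the ones that follow replaces a raw
// std::unique_ptr<Thread>, which is joined implicitly when reset, with
// BackgroundWorker: it owns one named thread, Schedule() enqueues closures
// onto that thread, and its destructor shuts the thread down. The usage
// shape, as it appears in the hunks below:
//
//   worker = MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread");
//   worker->Schedule([this, new_ctx]() { WorkerThread(new_ctx); });
//
// Separating construction from Schedule() lets the same worker accept work
// items over time instead of binding a single closure at thread creation.
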
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", - [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); + worker_threads_.emplace_back( + MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); + worker_threads_.back()->Schedule( + [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); } } return Status::OK(); @@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } workers_[i].SetInputs(s, std::move(args)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", - [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); + worker_threads_.emplace_back( + MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); + worker_threads_.back()->Schedule( + [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // The worker threads. This must be last to ensure the // threads have exited before any other members are deallocated. // TODO(b/65178177): Avoid allocating additional threads. - std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); + std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_ + GUARDED_BY(mu_); }; const DatasetBase* const input_; @@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - [this, new_ctx]() { RunnerThread(new_ctx); })); + runner_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); + runner_thread_->Schedule( + [this, new_ctx]() { RunnerThread(new_ctx); }); } } @@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr<thread::ThreadPool> thread_pool_; - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); // Identifies whether background activity should be cancelled. bool cancelled_ GUARDED_BY(*mu_) = false; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 6abe6c8338..3a14924fba 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
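
// One detail of the hunk below: when use_inter_op_parallelism_ is false, the
// freshly built map_func is wrapped so that invoking it only enqueues the
// real work on the iterator's runner. The decoration pattern, reduced to a
// self-contained sketch:

#include <functional>

namespace sketch {

using Closure = std::function<void()>;
using Runner = std::function<void(Closure)>;

// Wraps `work` so that calling the result defers `work` to `runner` instead
// of executing it on the caller's thread.
Closure ViaRunner(Runner runner, Closure work) {
  return [runner, work]() { runner(work); };
}

// Usage: Closure deferred = ViaRunner(runner, work); deferred(); enqueues
// `work` rather than running it inline.

}  // namespace sketch
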
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + ParallelMapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors, + std::move(done), prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args), + out_tensors, std::move(done))); + }; + } + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, std::vector<Tensor>* out_tensors, + StatusCallback done) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, output_shapes_, use_inter_op_parallelism_, - std::move(captured_func)); + std::move(captured_func), std::move(map_func)); } private: @@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, bool use_inter_op_parallelism, - std::unique_ptr<CapturedFunction> captured_func) + std::unique_ptr<CapturedFunction> captured_func, + ParallelMapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), @@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes), use_inter_op_parallelism_(use_inter_op_parallelism), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); - ParallelMapIteratorFunction map_func = - [this, new_prefix](IteratorContext* ctx, - std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done), new_prefix); - }; - if (!use_inter_op_parallelism_) { - map_func = [map_func]( - IteratorContext* ctx, std::vector<Tensor> input_element, - 
std::vector<Tensor>* result, StatusCallback done) { - (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), - result, std::move(done))); - }; - } - - return NewParallelMapIterator({this, new_prefix}, input_, - std::move(init_func), std::move(map_func), - num_parallel_calls_); + return NewParallelMapIterator( + {this, strings::StrCat(prefix, "::ParallelMap")}, input_, + std::move(init_func), map_func_, num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; const bool use_inter_op_parallelism_; const std::unique_ptr<CapturedFunction> captured_func_; + const ParallelMapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 13bd4b6036..e69274e4f2 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -179,10 +180,11 @@ class ParallelMapIterator : public DatasetBaseIterator { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); - runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); + runner_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); + runner_thread_->Schedule( + std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)); } } @@ -208,15 +210,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } - // Call `func_(input_element)`, store the result in `result->return_values`, - // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - map_func_(ctx.get(), std::move(input_element), &result->return_values, - std::move(done)); + // Apply the map function on `input_element`, storing the result in + // `result->return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + &result->return_values, std::move(done)); } Status ProcessResult(const std::shared_ptr<InvocationResult>& result, @@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator { // Buffer for storing the invocation results. 
std::deque<std::shared_ptr<InvocationResult>> invocation_results_ GUARDED_BY(*mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); bool cancelled_ GUARDED_BY(*mu_) = false; }; @@ -349,9 +351,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, std::function<Status(IteratorContext*)> init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr<IteratorBase>( - new ParallelMapIterator(params, input_dataset, std::move(init_func), - std::move(map_func), num_parallel_calls)); + return MakeUnique<ParallelMapIterator>( + params, input_dataset, std::move(init_func), std::move(map_func), + num_parallel_calls); } } // namespace data diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index dc26c5cf25..813f13c9e4 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -30,7 +30,7 @@ namespace data { // 3. A `std::vector<Tensor>*` to which the function will write the result. // 4. A `StatusCallback` that should be invoked when the function is complete. using ParallelMapIteratorFunction = - std::function<void(IteratorContext*, std::vector<Tensor>, + std::function<void(IteratorContext*, const string&, std::vector<Tensor>, std::vector<Tensor>*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 1d1a717062..7de5ea8860 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - auto map_fn = [this](IteratorContext* ctx, + auto map_fn = [this](IteratorContext* ctx, const string& prefix, std::vector<Tensor> input_element, std::vector<Tensor>* result, StatusCallback done) { (*ctx->runner())([this, ctx, input_element, result, done]() { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 754ed772db..e9c38eb8a0 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
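
// ptr_util.h, newly included here and in the files above, supplies MakeUnique
// as a C++11-compatible stand-in for std::make_unique. Its non-array form is
// essentially:

#include <memory>
#include <utility>

namespace sketch {

template <typename T, typename... Args>
std::unique_ptr<T> MakeUnique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

}  // namespace sketch
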
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status EnsurePrefetchThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!prefetch_thread_) { + prefetch_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread"); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - prefetch_thread_.reset(ctx->env()->StartThread( - {}, "prefetch_thread", - [this, new_ctx]() { PrefetchThread(new_ctx); })); + prefetch_thread_->Schedule( + [this, new_ctx]() { PrefetchThread(new_ctx); }); } return Status::OK(); } @@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_); - std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); + std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; }; diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 66466d6a36..9f54c381a9 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -485,7 +485,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { int64 buffer_size, int64 seed, int64 seed2, int64 count) : ShuffleDatasetBase(ctx, input, buffer_size, count), seed_(seed), - seed2_(seed) {} + seed2_(seed2) {} string DebugString() const override { return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 3f76695bb1..7bb2077b62 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel { public: explicit ToTFRecordOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - thread_pool_(new thread::ThreadPool( - ctx->env(), ThreadOptions(), - strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) {} + background_worker_( + ctx->env(), + strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) { + } template <typename T> Status ParseScalarArgument(OpKernelContext* ctx, @@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel { // The call to `iterator->GetNext()` may block and depend on an // inter-op thread pool thread, so we issue the call from the // owned thread pool. 
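    // (With this change the "owned thread pool" is a single-thread
    // BackgroundWorker; the reasoning stands: a GetNext() call that blocks
    // on inter-op work must not itself occupy an inter-op pool thread.)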
- thread_pool_->Schedule([this, ctx, done]() { + background_worker_.Schedule([this, ctx, done]() { string filename; OP_REQUIRES_OK_ASYNC( ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done); @@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel { } private: - std::unique_ptr<thread::ThreadPool> thread_pool_; + BackgroundWorker background_worker_; }; REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 04a53697c0..3810d817ca 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel { Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ RandomGammaOp<TYPE>) -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<CPUDevice, IntType>); TF_CALL_half(REGISTER); @@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT); random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>); -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<int32>("T") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<int32>("T") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<GPUDevice, IntType>); TF_CALL_half(REGISTER); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index 173fea37ed..e67695d54a 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -#define REGISTER_RELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - ReluOp<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - ReluGradOp<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - Relu6Op<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - Relu6GradOp<CPUDevice, type>) +#define REGISTER_RELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ReluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ReluGradOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + 
Relu6Op<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + Relu6GradOp<CPUDevice, type>) \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + LeakyReluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + LeakyReluGradOp<CPUDevice, type>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); #undef REGISTER_RELU_KERNELS @@ -99,6 +105,19 @@ namespace functor { extern template struct Relu6Grad<GPUDevice, T>; \ \ template <> \ + void LeakyRelu<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \ + typename TTypes<T>::Tensor activations); \ + extern template struct LeakyRelu<GPUDevice, T>; \ + \ + template <> \ + void LeakyReluGrad<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ + typename TTypes<T>::ConstTensor features, T alpha, \ + typename TTypes<T>::Tensor backprops); \ + extern template struct LeakyReluGrad<GPUDevice, T>; \ + \ + template <> \ void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \ typename TTypes<T>::ConstTensor features, \ typename TTypes<T>::Tensor activations); \ @@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - ReluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - ReluGradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - Relu6Op<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - Relu6GradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - EluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - EluGradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - SeluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + ReluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + ReluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + Relu6Op<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + Relu6GradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + LeakyReluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + LeakyReluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + EluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + EluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + 
Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SeluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ SeluGradOp<GPUDevice, type>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); @@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER( #ifdef TENSORFLOW_USE_SYCL // Registration of the GPU implementations. -#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - ReluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - ReluGradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - Relu6Op<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - Relu6GradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - EluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - EluGradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - SeluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + ReluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + ReluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + Relu6Op<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + Relu6GradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + LeakyReluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + LeakyReluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + EluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + EluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + SeluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ SeluGradOp<SYCLDevice, type>) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index 4775deeb61..a4638c70c2 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, } template <typename Device, typename T> +class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> { + public: + explicit LeakyReluOp(OpKernelConstruction* context) + : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + 
+    functor::LeakyRelu<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+            output->flat<T>());
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+    : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+  explicit LeakyReluGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+    float alpha_tmp;
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+    alpha_ = T(alpha_tmp);
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, T alpha, Tensor* output);
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+  //               outputs (using either one yields the same result here).
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, alpha_, output);
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                                   const Tensor& g,
+                                                   const Tensor& a, T alpha,
+                                                   Tensor* output) {
+  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+  functor::LeakyReluGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+          output->flat<T>());
+}
+
+template <typename Device, typename T>
 class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
  public:
   using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
   }
 };
 
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+  // Computes LeakyRelu activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  T alpha, typename TTypes<T>::Tensor activations) {
+    activations.device(d) = features.cwiseMax(features * alpha);
+  }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+  // Computes LeakyReluGrad backprops.
+  //
+  // gradients: gradients backpropagated to the LeakyRelu op.
+  // features: either the inputs that were passed to the LeakyRelu op, or its
+  //           outputs (using either one yields the same result here).
+  // backprops: gradients to backpropagate to the LeakyRelu inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features, T alpha,
+                  typename TTypes<T>::Tensor backprops) {
+    backprops.device(d) =
+        (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+  }
+};
+
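The two functors just added encode the same piecewise function in different ways: the forward pass writes `max(x, alpha * x)`, which matches the usual piecewise LeakyRelu definition whenever `alpha <= 1`, while the gradient selects between `g` and `alpha * g` on the sign of `features`. A scalar reference version with a few spot checks — illustrative only, not part of the patch:

    #include <algorithm>
    #include <cassert>

    // Scalar reference for the LeakyRelu functors above. Note that the
    // forward form max(x, alpha * x) equals the piecewise definition only
    // for alpha <= 1; the gradient uses the piecewise form directly.
    float LeakyReluRef(float x, float alpha) { return std::max(x, alpha * x); }

    // Mirrors (features > 0).select(gradients, gradients * alpha).
    float LeakyReluGradRef(float g, float feature, float alpha) {
      return feature > 0.0f ? g : alpha * g;
    }

    int main() {
      const float alpha = 0.2f;  // the op's default
      assert(LeakyReluRef(3.0f, alpha) == 3.0f);    // positive side: identity
      assert(LeakyReluRef(-5.0f, alpha) == -1.0f);  // negative side: alpha * x
      assert(LeakyReluGradRef(1.0f, 3.0f, alpha) == 1.0f);
      assert(LeakyReluGradRef(1.0f, -5.0f, alpha) == alpha);
      return 0;
    }

 // Functor used by EluOp to do the computations.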
template <typename Device, typename T> struct Elu { diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index b9391517c1..dd5f9495e2 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -145,14 +145,16 @@ struct Relu<Device, qint8> { } // namespace functor // Definition of the GPU implementations declared in relu_op.cc. -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Relu<GPUDevice, T>; \ - template struct functor::ReluGrad<GPUDevice, T>; \ - template struct functor::Relu6<GPUDevice, T>; \ - template struct functor::Relu6Grad<GPUDevice, T>; \ - template struct functor::Elu<GPUDevice, T>; \ - template struct functor::EluGrad<GPUDevice, T>; \ - template struct functor::Selu<GPUDevice, T>; \ +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu<GPUDevice, T>; \ + template struct functor::ReluGrad<GPUDevice, T>; \ + template struct functor::Relu6<GPUDevice, T>; \ + template struct functor::Relu6Grad<GPUDevice, T>; \ + template struct functor::LeakyRelu<GPUDevice, T>; \ + template struct functor::LeakyReluGrad<GPUDevice, T>; \ + template struct functor::Elu<GPUDevice, T>; \ + template struct functor::EluGrad<GPUDevice, T>; \ + template struct functor::Selu<GPUDevice, T>; \ template struct functor::SeluGrad<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index eab176c7fb..925f5291a6 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +template <typename Device, typename IntType> +class StatelessRandomUniformIntOp : public StatelessRandomOpBase { + public: + using StatelessRandomOpBase::StatelessRandomOpBase; -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); + void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) override { + const Tensor& minval = context->input(2); + const Tensor& maxval = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. Note that we'll never reach this point for + // empty output. Zero impossible things are fine. 
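The `lo`/`hi` reads and the `minval < maxval` check continue just below. The guarantee the op provides overall is that the output is a pure function of `(shape, seed, minval, maxval)`, with every value drawn from the half-open range `[minval, maxval)`. A self-contained analogue of that contract — substituting `std::mt19937_64` for TensorFlow's Philox generator, so the actual values differ, but the determinism does not:

    #include <cstdint>
    #include <random>
    #include <stdexcept>
    #include <vector>

    // Deterministic analogue of StatelessRandomUniformInt: the same
    // (size, seed pair, minval, maxval) always yields the same values.
    std::vector<std::int64_t> StatelessUniformInt(std::size_t size,
                                                  std::uint64_t seed0,
                                                  std::uint64_t seed1,
                                                  std::int64_t minval,
                                                  std::int64_t maxval) {
      if (minval >= maxval) {  // mirrors the OP_REQUIRES check above
        throw std::invalid_argument("Need minval < maxval");
      }
      std::seed_seq seq{seed0, seed1};  // the op's "seed" input has shape [2]
      std::mt19937_64 gen(seq);
      // uniform_int_distribution is inclusive on both ends, so pass
      // maxval - 1 to get the op's half-open [minval, maxval).
      std::uniform_int_distribution<std::int64_t> dist(minval, maxval - 1);
      std::vector<std::int64_t> out(size);
      for (auto& v : out) v = dist(gen);
      return out;
    }

Calling this twice with identical arguments returns identical vectors; that reproducibility is the point of the "stateless" op family, in contrast to `RandomUniformInt`, which advances kernel-held state between calls.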
+ const auto lo = minval.scalar<IntType>()(); + const auto hi = maxval.scalar<IntType>()(); + OP_REQUIRES( + context, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution<random::PhiloxRandom, IntType> + Distribution; + Distribution dist(lo, hi); + + auto flat = output->flat<IntType>(); + // Reuse the compute kernels from the stateful random ops + functor::FillPhiloxRandom<Device, Distribution>()( + context, context->eigen_device<Device>(), random, flat.data(), + flat.size(), dist); + } +}; -#undef REGISTER +#define REGISTER(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomUniform") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessTruncatedNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp< \ + DEVICE##Device, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); + +#define REGISTER_INT(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomUniformIntOp<DEVICE##Device, TYPE>); + +#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE) +#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) +#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE) +#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE) + +TF_CALL_half(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_int32(REGISTER_INT_CPU); +TF_CALL_int64(REGISTER_INT_CPU); #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - GPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +TF_CALL_int32(REGISTER_INT_GPU); +TF_CALL_int64(REGISTER_INT_GPU); -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); +#endif // GOOGLE_CUDA #undef REGISTER - -#endif // GOOGLE_CUDA +#undef REGISTER_INT +#undef REGISTER_CPU +#undef REGISTER_GPU +#undef REGISTER_INT_CPU +#undef 
REGISTER_INT_GPU } // namespace diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index 3559baa18e..3bdcfc90b8 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -108,7 +108,7 @@ class UniqueOp : public OpKernel { std::unordered_map<T, TIndex> uniq; uniq.reserve(2 * N); - for (int64 i = 0, j = 0; i < N; ++i) { + for (Eigen::Index i = 0, j = 0; i < N; ++i) { auto it = uniq.insert(std::make_pair(Tin(i), j)); idx_vec(i) = it.first->second; if (it.second) { @@ -131,19 +131,20 @@ class UniqueOp : public OpKernel { // General implementation when unique is run over multiple elements. auto Tin = input.shaped<T, 3>(new_sizes); - auto hash_fn = [&Tin](const int64& key) { + auto hash_fn = [&Tin](const Eigen::Index& key) { size_t h = 0; - for (int64 i = 0; i < Tin.dimension(0); i++) { - for (int64 j = 0; j < Tin.dimension(2); j++) { + for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { + for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { h = Hash64Combine(h, hash<T>{}(Tin(i, key, j))); } } return h; }; - auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) { - for (int64 i = 0; i < Tin.dimension(0); i++) { - for (int64 j = 0; j < Tin.dimension(2); j++) { + auto equal_to_fn = [&Tin](const Eigen::Index& lhs, + const Eigen::Index& rhs) { + for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { + for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { if (Tin(i, lhs, j) != Tin(i, rhs, j)) { return false; } diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 780c6f6448..9df0ece69b 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -28981,6 +28981,74 @@ op { } } op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { + name: "LeakyReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "LearnedUnigramCandidateSampler" input_arg { name: "true_classes" @@ -70897,6 +70965,62 @@ op { } } op { + name: "StatelessRandomNormal" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessRandomUniform" input_arg { name: "shape" @@ -70994,6 +71118,118 @@ op { } } op { + name: 
"StatelessRandomUniform" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { + name: "StatelessRandomUniformInt" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + input_arg { + name: "minval" + type_attr: "dtype" + } + input_arg { + name: "maxval" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessTruncatedNormal" input_arg { name: "shape" @@ -71091,6 +71327,62 @@ op { } } op { + name: "StatelessTruncatedNormal" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessWhile" input_arg { name: "input" diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3eff728f03..a9e5e7824d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. 
+ int32 size_val = size_tensor->scalar<int32>()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index be4c3ed2b6..05379a7d69 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index d1d81b27cc..a9ca69ad86 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad") .Attr("T: realnumbertype") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("LeakyRelu") + .Input("features: T") + .Output("activations: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("LeakyReluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + REGISTER_OP("Elu") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 0d8997c1bd..2048ad26ac 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -14296,6 +14296,74 @@ op { } } op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { + name: "LeakyReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "LearnedUnigramCandidateSampler" input_arg { name: "true_classes" @@ -32978,6 +33046,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -33033,6 +33102,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -33066,6 +33136,62 @@ op { } } op { + name: "StatelessRandomUniformInt" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + input_arg { + name: "minval" + type_attr: "dtype" + } + input_arg { + name: "maxval" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list 
{ + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessTruncatedNormal" input_arg { name: "shape" @@ -33088,6 +33214,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index adc9cd1486..65bdde375b 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp") Status VariableShapeShapeFn(InferenceContext* c) { auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data == nullptr || handle_data->empty()) { - return errors::InvalidArgument("Handle doesn't have shape information."); + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); } ShapeHandle var_shape = (*handle_data)[0].shape; int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape) diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 742709fb18..f919a21d60 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -19,42 +19,55 @@ limitations under the License. namespace tensorflow { using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShape(shape_inference::InferenceContext* context) { +static Status StatelessShape(InferenceContext* c) { // Check seed shape ShapeHandle seed; - TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); DimensionHandle unused; - TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); // Set output shape ShapeHandle out; - TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out)); - context->set_output(0, out); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); return Status::OK(); } -#define REGISTER_STATELESS_OP(name) \ - REGISTER_OP(name) \ - .Input("shape: T") \ - .Input("seed: Tseed") \ - .Output("output: dtype") \ - .Attr("dtype: {half,float,double} = DT_FLOAT") \ - .Attr("T: {int32, int64} = DT_INT32") \ - .Attr("Tseed: {int32, int64} = DT_INT64") \ +#define REGISTER_STATELESS_OP(name) \ + REGISTER_OP(name) \ + .Input("shape: T") \ + .Input("seed: Tseed") \ + .Output("output: dtype") \ + .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \ + .Attr("T: {int32, int64} = DT_INT32") \ + .Attr("Tseed: {int32, int64} = DT_INT64") \ .SetShapeFn(StatelessShape) -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomUniform"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomNormal"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessTruncatedNormal"); -// This op is exposed through contrib/stateless only. The interface may change. 
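The refactored `StatelessShape` helper above captures the two rules every stateless RNG op shares: the `seed` input must be a vector of exactly two elements, and the output shape is taken from the value of the `shape` input. A plain-C++ restatement of that validation — an illustrative analogue; the real helper operates on `InferenceContext` handles and returns `Status`:

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Analogue of StatelessShape: validate the seed, then adopt the value
    // of the `shape` input as the output shape. Throws where the real
    // helper would return a non-OK Status.
    std::vector<std::int64_t> StatelessShapeRef(
        const std::vector<std::int64_t>& seed,
        const std::vector<std::int64_t>& shape_value) {
      // WithRank(seed, 1) + WithValue(Dim(seed, 0), 2): seed has shape [2].
      if (seed.size() != 2) {
        throw std::invalid_argument("seed must be a vector with 2 elements");
      }
      // MakeShapeFromShapeTensor(0, &out): the output shape is the shape
      // tensor's contents, e.g. {3, 4} produces a [3, 4] output.
      return shape_value;
    }

`StatelessRandomUniformInt`, registered just below, layers two extra scalar (rank-0) checks for `minval` and `maxval` on top before delegating to `StatelessShape`.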
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+    .Input("shape: T")
+    .Input("seed: Tseed")
+    .Input("minval: dtype")
+    .Input("maxval: dtype")
+    .Output("output: dtype")
+    .Attr("dtype: {int32, int64}")
+    .Attr("T: {int32, int64}")
+    .Attr("Tseed: {int32, int64} = DT_INT64")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      return StatelessShape(c);
+    });
+
 REGISTER_OP("StatelessMultinomial")
     .Input("logits: T")
     .Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
       return Status::OK();
     });
 
-#undef REGISTER_STATELESS_OP
-
 } // namespace tensorflow
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8c31468ff5..7ccd54b818 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -83,6 +83,10 @@ message RewriterConfig {
   // Controls how many times we run the optimizers in meta optimizer (default
   // is once).
   NumIterationsType meta_optimizer_iterations = 12;
+  // Maximum number of milliseconds to spend optimizing a single graph before
+  // timing out. If equal to 0, the system picks a default (currently 5
+  // minutes). If less than 0, the optimizer will never time out.
+  int64 meta_optimizer_timeout_ms = 20;
 
   // The minimum number of nodes in a graph to optimize. For smaller graphs,
   // optimization is skipped.
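The new `meta_optimizer_timeout_ms` field follows a three-way convention: zero selects the built-in default, a negative value disables the timeout, and a positive value is an explicit budget in milliseconds. A sketch of how a consumer might normalize the field, assuming only the semantics stated in the comment above; the helper name is illustrative:

    #include <cstdint>

    // Normalizes RewriterConfig.meta_optimizer_timeout_ms per its comment:
    //   < 0  -> never time out (represented here as -1)
    //   == 0 -> system default, documented as 5 minutes
    //   > 0  -> explicit timeout in milliseconds
    std::int64_t EffectiveTimeoutMs(std::int64_t meta_optimizer_timeout_ms) {
      if (meta_optimizer_timeout_ms < 0) return -1;              // no deadline
      if (meta_optimizer_timeout_ms == 0) return 5 * 60 * 1000;  // 5 minutes
      return meta_optimizer_timeout_ms;
    }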