Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt | 5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt | 24
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt | 46
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.cc | 35
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.h | 4
-rw-r--r--  tensorflow/core/common_runtime/executor.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.cc | 5
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.h | 5
-rw-r--r--  tensorflow/core/common_runtime/lower_if_op.cc | 9
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.cc | 5
-rw-r--r--  tensorflow/core/framework/shape_inference.cc | 9
-rw-r--r--  tensorflow/core/framework/shape_inference.h | 9
-rw-r--r--  tensorflow/core/graph/graph.cc | 13
-rw-r--r--  tensorflow/core/graph/graph.h | 5
-rw-r--r--  tensorflow/core/graph/node_builder.cc | 8
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 22
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/BUILD | 10
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc | 16
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc | 16
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h | 19
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h | 15
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc | 11
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h | 44
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 116
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer.h | 21
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 75
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.h | 15
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer_test.cc | 62
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc | 162
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 76
-rw-r--r--  tensorflow/core/kernels/data/BUILD | 14
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils.cc | 47
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils.h | 20
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils_test.cc | 46
-rw-r--r--  tensorflow/core/kernels/data/filter_dataset_op.cc | 162
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 196
-rw-r--r--  tensorflow/core/kernels/data/map_dataset_op.cc | 62
-rw-r--r--  tensorflow/core/kernels/data/model_dataset_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 27
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 79
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.cc | 26
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.h | 2
-rw-r--r--  tensorflow/core/kernels/data/parse_example_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/data/shuffle_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/writer_ops.cc | 12
-rw-r--r--  tensorflow/core/kernels/random_op.cc | 34
-rw-r--r--  tensorflow/core/kernels/relu_op.cc | 153
-rw-r--r--  tensorflow/core/kernels/relu_op.h | 61
-rw-r--r--  tensorflow/core/kernels/relu_op_functor.h | 30
-rw-r--r--  tensorflow/core/kernels/relu_op_gpu.cu.cc | 18
-rw-r--r--  tensorflow/core/kernels/stateless_random_ops.cc | 155
-rw-r--r--  tensorflow/core/kernels/unique_op.cc | 15
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 292
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 19
-rw-r--r--  tensorflow/core/ops/math_ops_test.cc | 12
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 15
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 127
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc | 3
-rw-r--r--  tensorflow/core/ops/stateless_random_ops.cc | 53
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 4
64 files changed, 1913 insertions, 671 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
new file mode 100644
index 0000000000..280148e032
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "LeakyRelu"
+ visibility: HIDDEN
+ summary: "Computes rectified linear: `max(features, features * alpha)`."
+}
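Note: the summary above corresponds to the following elementwise function. This is an illustrative standalone sketch only, not the TensorFlow kernel (which is added to relu_op*.h/.cc later in this diff):

// Leaky ReLU forward pass for a single element; for alpha <= 1 this is
// equivalent to max(x, alpha * x) as stated in the summary.
float LeakyRelu(float x, float alpha) { return x > 0.0f ? x : alpha * x; }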
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
new file mode 100644
index 0000000000..e427526602
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LeakyReluGrad"
+ visibility: HIDDEN
+ in_arg {
+ name: "gradients"
+ description: <<END
+The backpropagated gradients to the corresponding LeakyRelu operation.
+END
+ }
+ in_arg {
+ name: "features"
+ description: <<END
+The features passed as input to the corresponding LeakyRelu operation,
+OR the outputs of that operation (both work equivalently).
+END
+ }
+ out_arg {
+ name: "backprops"
+ description: <<END
+`gradients * (features > 0) + alpha * gradients * (features <= 0)`.
+END
+ }
+ summary: "Computes rectified linear gradients for a LeakyRelu operation."
+}
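Note: the backprops formula above, written out as a standalone elementwise reference (a minimal sketch only; the real kernels are added to relu_op*.h/.cc and relu_op_gpu.cu.cc later in this diff):

#include <vector>

// Elementwise LeakyRelu gradient: pass the incoming gradient through where
// the corresponding feature is positive, otherwise scale it by alpha.
std::vector<float> LeakyReluGradReference(const std::vector<float>& gradients,
                                          const std::vector<float>& features,
                                          float alpha) {
  std::vector<float> backprops(features.size());
  for (size_t i = 0; i < features.size(); ++i) {
    backprops[i] = features[i] > 0.0f ? gradients[i] : alpha * gradients[i];
  }
  return backprops;
}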
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+ graph_op_name: "StatelessRandomUniformInt"
+ visibility: HIDDEN
+ in_arg {
+ name: "shape"
+ description: <<END
+The shape of the output tensor.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ in_arg {
+ name: "minval"
+ description: <<END
+Minimum value (inclusive, scalar).
+END
+ }
+ in_arg {
+ name: "maxval"
+ description: <<END
+Maximum value (exclusive, scalar).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Random values with specified shape.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of the output.
+END
+ }
+ summary: "Outputs deterministic pseudorandom integers from a uniform distribution."
+ description: <<END
+The generated values follow a uniform distribution in the range `[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
+END
+}
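Note: to make the "deterministic function of `shape`, `seed`, `minval`, and `maxval`" property concrete, here is a minimal illustrative sketch. It keys an ordinary std::mt19937_64 on the two seed values, whereas the real op (see stateless_random_ops.cc later in this diff) uses a counter-based Philox generator:

#include <cstdint>
#include <random>
#include <vector>

// Stateless uniform ints: identical arguments always yield identical output.
// The seed-mixing constant below is arbitrary and only for illustration.
std::vector<int64_t> StatelessUniformIntSketch(uint64_t seed0, uint64_t seed1,
                                               int64_t num_elements,
                                               int64_t minval, int64_t maxval) {
  std::mt19937_64 gen(seed0 ^ (seed1 * 0x9E3779B97F4A7C15ULL));
  // uniform_int_distribution is inclusive on both ends, so use maxval - 1
  // to get the half-open range [minval, maxval).
  std::uniform_int_distribution<int64_t> dist(minval, maxval - 1);
  std::vector<int64_t> out(num_elements);
  for (auto& v : out) v = dist(gen);
  return out;
}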
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index db137f1a19..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,23 +466,23 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
- int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
+ int64 max_constant_size_in_bytes,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 3) If the size of the constant in bytes is too large (>
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+ // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
- // 4) If the constant op created does not have a kernel implementation
+ // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on its
// output.
- // 5) If the constant op for the device has different output memory type
- // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if (memory_type == HOST_MEMORY && !is_int32) {
+ if ((memory_type == HOST_MEMORY && !is_int32) ||
+ (memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (!disable_memory_output_type_check) {
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
- }
- }
- }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
@@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
- control_deps, opts.max_constant_size_in_bytes,
- opts.disable_memory_output_type_check, generate_new_name)) {
+ control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
++num_nodes_replaced;
}
}
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 4c71b7bd27..a9a84f761b 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
// optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
- // If disable_memory_output_type_check is true, we will disable output memory
- // type check for constant node replacement.
- bool disable_memory_output_type_check = false;
-
// A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is
// passed.
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 2c48084cab..40ec1502da 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -1240,6 +1241,7 @@ class ExecutorState {
StepStatsCollectorInterface* const stats_collector_;
const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
+ Context context_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
+ context_(ContextKind::kThread),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
@@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item,
}
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+ WithContext wc(context_);
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 91194bc86f..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn,
- const std::function<bool(const Node*)>& cf_consider_fn,
- bool cf_disable_memory_output_type_check) {
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn;
- cf_opts.disable_memory_output_type_check =
- cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 8954e9612d..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -47,16 +47,13 @@ class GraphOptimizer {
// returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF.
- // If cf_disable_memory_output_type_check is true, CF will discard output
- // memory type check for constant node replacement.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
- const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
- bool cf_disable_memory_output_type_check = false);
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce the number of outputs of an op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 1b5a215987..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) {
}
bool IsControlFlow(const NodeDef& node) {
- // clang-format off
- return node.op() == "ControlTrigger" ||
- node.op() == "Enter" ||
- node.op() == "Exit" ||
- node.op() == "LoopCond" ||
- node.op() == "Merge" ||
- node.op() == "NextIteration" ||
- node.op() == "Switch";
- // clang-format on
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControlFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControlFlowOps->count(node.op()) > 0;
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 37aa24b947..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -13,9 +13,19 @@ VECTORIZER_DEPS = [
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 74ce520ce1..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -19,15 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
- if (node.num_inputs() != 1) {
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
@@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer {
int new_axis = node.def().attr().at("axis").i() + 1;
new_unpack_node->AddAttr("axis", new_axis);
- // Add the input mappings
- input_ports->push_back({new_unpack_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
// Add the output mappings
int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- output_ports->push_back({new_unpack_node, i});
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index 56eb88c95e..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -18,15 +18,12 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
-
-// Describes a tensor with its operation Node and output position
-typedef std::pair<Node*, int> Port;
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -36,17 +33,17 @@ class Vectorizer {
// Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs. The new Node(s) collectively have the
+ // on elements of `inputs`. The new Node(s) collectively have the
// same number of input and output ports as the node being converted.
- // Adds mappings for the new nodes' input and output ports to `inputs` and
- // `outputs` respectively, where the i'th Port in inputs/outputs
- // corresponds to the i'th input/output port of the node to be converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings to the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
virtual Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) = 0;
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 663ceba027..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* inputs,
- std::vector<Port>* outputs) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) {
NodeDef node_def;
Status s;
Node* node = g.AddNode(node_def, &s);
- std::vector<Port> inputs, outputs;
- EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 344c420902..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -45,22 +45,6 @@ namespace {
// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;
-// Equivalent to python Pfor's WrappedTensor struct
-struct WrappedTensor {
- TensorDesc tensor;
-
- // Whether the tensor is stacked, i.e. represents the results of applying
- // the operation on all slices of the input, where each row i of the
- // tensor corresponds to the op's output on slice i of the input. False
- // if the tensor is not stacked, i.e. represents the result of the op on
- // a single slice of the input, where the result does not vary between
- // slices.
- bool stacked;
-
- WrappedTensor(TensorDesc&& tensor, bool stacked)
- : tensor(std::move(tensor)), stacked(stacked) {}
-};
-
const char* const kRetValOp = "_Retval";
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
@@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
@@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (found->stacked) {
- converted_output = found->tensor;
- } else {
- // Some outputs may be unstacked if they don't derive from arg nodes
- // (for example, if a function returns a constant). For these, we
- // have to add extra nodes to tile it in the 0th dimension.
- TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
- }
- } else {
- // Note: All unstacked nodes are converted ahead of time in `Initialize`,
- // and here we assume that all op vectorizers create only stacked outputs.
- // This may not hold in the future, as more vectorizers are added that
- // may actually create unstacked outputs. For example, see the `Shape`
- // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output).tensor;
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* ones_shape;
TF_RETURN_IF_ERROR(node_builder("Shape")
- .Input(unstacked->tensor.first) // input
+ .Input(unstacked->node) // input
.Finalize(g, &ones_shape));
Node* ones;
@@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* expand_dims;
TF_RETURN_IF_ERROR(node_builder("ExpandDims")
- .Input(unstacked->tensor.first) // input
- .Input(const_0) // dim
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
.Finalize(g, &expand_dims));
TF_RETURN_IF_ERROR(node_builder("Tile")
@@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() {
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
// Control inputs
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {{input_node, Graph::kControlSlot}, true}});
+ {input_node, Graph::kControlSlot, true}});
}
return Status::OK();
}
@@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
if (auto found = gtl::FindOrNull(conversion_map_,
{edge->src(), edge->src_output()})) {
- outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ outer_scope_->AddEdge(found->node, found->output_index, node,
edge->dst_input());
} else {
status->Update(errors::Internal(
@@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
// Add output mappings
for (int i = 0; i < tensor.first->num_outputs(); ++i) {
- conversion_map_.insert(
- {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
}
conversion_map_.insert({{tensor.first, Graph::kControlSlot},
- WrappedTensor({node, Graph::kControlSlot}, false)});
+ WrappedTensor(node, Graph::kControlSlot, false)});
return true;
}
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index 765dd13263..bd6bf9f860 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#include <atomic>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace grappler {
@@ -29,6 +32,7 @@ struct GrapplerItem;
// optimization of a GrapplerItem for running on a cluster.
class GraphOptimizer {
public:
+ GraphOptimizer() : is_cancelled_(false) {}
virtual ~GraphOptimizer() {}
virtual string name() const = 0;
@@ -45,8 +49,25 @@ class GraphOptimizer {
// call to Optimize) performed. Lower "result" scores are better.
virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) = 0;
+
+ // Best-effort cancellation. Sets is_cancelled_ to true and requests that the
+ // optimizer return as soon as possible from active calls to Optimize() or
+ // Feedback().
+ void Cancel() { is_cancelled_ = true; }
+
+ bool is_cancelled() const { return is_cancelled_; }
+
+ private:
+ std::atomic<bool> is_cancelled_;
};
+#define GRAPPLER_RETURN_IF_CANCELLED() \
+ do { \
+ if (is_cancelled()) { \
+ return errors::DeadlineExceeded(this->name(), " was cancelled."); \
+ } \
+ } while (0)
+
} // end namespace grappler
} // end namespace tensorflow
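Note: an optimizer that wants to honor this best-effort cancellation simply checks GRAPPLER_RETURN_IF_CANCELLED() between expensive steps, which is what the MetaOptimizer changes below do. A hypothetical subclass, as an illustrative sketch only:

#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

// Hypothetical long-running optimizer that cooperates with Cancel().
class MyLongRunningOptimizer : public GraphOptimizer {
 public:
  string name() const override { return "my_long_running_optimizer"; }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    for (int pass = 0; pass < 100; ++pass) {
      // Returns errors::DeadlineExceeded once Cancel() has been called.
      GRAPPLER_RETURN_IF_CANCELLED();
      // ... one expensive rewrite pass over *optimized_graph ...
    }
    return Status::OK();
  }

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override {}
};

}  // namespace grappler
}  // namespace tensorflow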
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c3d70a1fdf..7488cedec5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include <memory>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -37,7 +40,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -107,13 +114,29 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
#undef MK_OPT
+MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+ : cpu_device_(cpu_device), cfg_(cfg) {
+ // TODO(rmlarsen): Increase kNumThreads to, say, port::NumSchedulableCPUs()
+ // if we want to use the threadpool for parallelizing Grappler.
+ const int kNumThreads = 1;
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(
+ Env::Default(), "MetaOptimizerThreadPool", kNumThreads);
+}
+
+MetaOptimizer::~MetaOptimizer() {
+ // The ThreadPool destructor waits for threads to finish, so we don't
+ // pull the rug out from under them.
+ thread_pool_.reset();
+}
+
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
if (cfg_.disable_meta_optimizer()) {
@@ -139,7 +162,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
- if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -309,6 +332,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
+ GRAPPLER_RETURN_IF_CANCELLED();
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
@@ -367,6 +391,7 @@ Status MetaOptimizer::RunOptimizer(
// resets optimized_graph to an empty graph.
optimized_graph->Swap(&optimized_item->graph);
*optimized_graph = GraphDef();
+ // TODO(rmlarsen): Add timeout for individual optimizers.
Status status =
optimizer->Optimize(cluster, *optimized_item, optimized_graph);
uint64 end_us = Env::Default()->NowMicros();
@@ -388,14 +413,15 @@ Status MetaOptimizer::RunOptimizer(
return status;
}
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeMainGraphAndFunctionLibrary(
+ Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
VLOG(1) << "Optimized main graph.";
+ GRAPPLER_RETURN_IF_CANCELLED();
// Skip optimizing functions if this is a TPU graph. Currently, Grappler
// passes do not handle TPU functions correctly in a variety of ways (Note
@@ -431,6 +457,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimize_function_library = false;
for (const FunctionDef& func : optimized_graph->library().function()) {
+ GRAPPLER_RETURN_IF_CANCELLED();
+
const string& func_name = func.signature().name();
// Skip already optimized functions.
@@ -505,6 +533,43 @@ void MetaOptimizer::PrintResult() {
}
}
+Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ const int64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+ const int64 timeout_usec = (cfg_.meta_optimizer_timeout_ms() == 0
+ ? kFiveMinutesInUsec
+ : cfg_.meta_optimizer_timeout_ms() * 1000);
+ if (timeout_usec < 0) {
+ return OptimizeMainGraphAndFunctionLibrary(cluster, item, optimized_graph);
+ }
+
+ GraphDef optimized_with_timeout;
+ Status status;
+ Notification done;
+ thread_pool_->Schedule(
+ [this, cluster, &done, &optimized_with_timeout, &item, &status]() {
+ status = this->OptimizeMainGraphAndFunctionLibrary(
+ cluster, item, &optimized_with_timeout);
+ done.Notify();
+ });
+
+ const bool notified = WaitForNotificationWithTimeout(&done, timeout_usec);
+ if (notified && status.ok()) {
+ optimized_graph->Swap(&optimized_with_timeout);
+ } else {
+ *optimized_graph = item.graph;
+ if (!notified) {
+ this->Cancel();
+ done.WaitForNotification();
+ status = errors::DeadlineExceeded(
+ "Grappler MetaOptimizer timed out after ",
+ static_cast<float>(timeout_usec) / (1000 * 1000), " seconds");
+ LOG(WARNING) << status.error_message();
+ }
+ }
+ return status;
+}
+
void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& pruned_graph, double result) {
// Nothing to do for MetaOptimizer.
@@ -527,7 +592,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
- cfg.pin_to_host_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 99a0a33ffa..35d6a4559b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -28,9 +29,8 @@ namespace grappler {
// Run the other grappler optimizers based on the specified rewriter config.
class MetaOptimizer : public GraphOptimizer {
public:
- MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
- : cpu_device_(cpu_device), cfg_(cfg) {}
- ~MetaOptimizer() override = default;
+ MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg);
+ ~MetaOptimizer();
string name() const override { return "meta_optimizer"; };
@@ -65,9 +65,18 @@ class MetaOptimizer : public GraphOptimizer {
Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph);
+ // Run optimization passes over the main graph and for functions in the
+ // function library.
+ Status OptimizeMainGraphAndFunctionLibrary(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph);
+
DeviceBase* const cpu_device_; // may be NULL
RewriterConfig cfg_;
+ // Thread pool used for launching optimizers asynchronously.
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+
struct OptimizerResult {
string optimizer_name;
string result;
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 3f3f43382f..7f1dd91f09 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -461,6 +461,68 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
}
+class SleepingOptimizer : public CustomGraphOptimizer {
+ public:
+ SleepingOptimizer() {}
+ string name() const override { return "test_optimizer"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ optimized_graph->add_node();
+ sleep(1);
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+};
+
+REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer);
+
+TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ EXPECT_EQ(status.error_message(),
+ "Grappler MetaOptimizer timed out after 1.5 seconds");
+ // Make sure the graph was reverted to the original regardless of when the
+ // optimizer timed out.
+ CompareGraphs(item.graph, output);
+}
+
+TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(item.graph.node_size() + 1, output.node_size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 8ed4271fa4..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,16 +25,29 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
return
@@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices,
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(const GraphView& graph,
- GraphProperties* properties,
- const NodeDef& node, int port_id,
- bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
@@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
for (const auto& fanin : graph.GetFanins(node, false)) {
bool fanin_candidate = false;
TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
if (!fanin_candidate) {
return Status::OK();
}
@@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
return Status::OK();
}
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
// Check if op's output port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
LOG(WARNING) << "Invalid port: " << port_id << "!\n"
<< node.DebugString() << "\n"
<< op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
&kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
}
}
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
return Status::OK();
}
// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
// If node is on Host, assume its inputs are Host friendly.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
return true;
}
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
+ }
+
// Check if op's input port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
// Check if the input_arg is pinned to Host.
for (const string& host_memory_arg : kernel->host_memory_arg()) {
if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
return true;
}
}
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
return false;
}
@@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
-Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
- const NodeDef& node, bool* is_candidate) {
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- *is_candidate = true;
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
return Status::OK();
}
- // Skip these node types.
- if (IsBlacklisted(node)) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
return Status::OK();
}
@@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
return Status::OK();
}
- // Check all inputs are Host friendly.
- for (const GraphView::OutputPort& fanin :
- graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
- bool fanin_candidate = false;
- TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
- if (!fanin_candidate) {
- return Status::OK();
- }
- }
-
// Check all outputs are Host friendly.
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
@@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
}
}
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
*is_candidate = true;
return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device; return false.
return false;
}
+
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
+ // Caches, for each {op, device, port} tuple, whether that port is pinned to host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
+
for (auto& node : *optimized_graph->mutable_node()) {
bool is_candidate = false;
- TF_RETURN_IF_ERROR(
- internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
// The consumer is not Host friendly, swap it back to the original device.
- if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
- fanout.port_id)) {
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
-// Try and find an appropriate Host device in `devices` given `device`.
+// Tries to swap `device` to a Host device from `devices`; true iff swapped.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7c64529441..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
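
A related note on how this optimizer is switched on end to end: a minimal sketch assuming the `pin_to_host_optimization` toggle exposed by RewriterConfig (the field name is an assumption here, not established by this change):

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Assumed field: RewriterConfig::pin_to_host_optimization (a Toggle).
tensorflow::ConfigProto MakePinToHostConfig() {
  tensorflow::ConfigProto config;
  config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_pin_to_host_optimization(tensorflow::RewriterConfig::ON);
  return config;
}
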
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 451f8c1a6c..37c1c54786 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -45,6 +45,16 @@ cc_library(
],
)
+tf_cc_test(
+ name = "dataset_utils_test",
+ srcs = ["dataset_utils_test.cc"],
+ deps = [
+ ":dataset_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
@@ -205,6 +215,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -232,6 +243,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -245,6 +257,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -285,6 +298,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e10833f525..a40f7f2146 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,10 +15,57 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices) {
+ FunctionLibraryRuntime::Handle fn_handle;
+ TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
+ func.name(), AttrSlice(&func.attr()), &fn_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
+ Status s = ctx->function_library()->ReleaseHandle(fn_handle);
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to release handle: " << s.error_message();
+ }
+ });
+
+ const FunctionBody* fn_body =
+ ctx->function_library()->GetFunctionBody(fn_handle);
+ indices->resize(fn_body->ret_nodes.size());
+ for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
+ Node* ret_node = fn_body->ret_nodes[i];
+ Node* ret_input_node;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
+ if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
+ } else {
+ indices->clear();
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
+ std::map<int, int> last_use;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ last_use[indices[i]] = i;
+ }
+ std::vector<bool> can_move;
+ can_move.resize(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ can_move[i] = last_use[indices[i]] == i;
+ }
+ return can_move;
+}
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6ec1350cd4..d777062293 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
namespace data {
+// This method is used to determine whether we can short-circuit the evaluation
+// of the user-defined function `func`. Short-circuiting is possible if every
+// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
+// (y,x)`, or `f(x) = (x,x)`).
+//
+// If short-circuiting is possible, the method stores the mapping from output
+// indices to input indices in `indices`. Otherwise, `indices` will be empty.
+//
+// Returns non-ok status if analysis of the function fails.
+//
+// TODO(jsimsa): Extend this to support constants as well.
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices);
+
+// Given a vector that maps output indices to input indices, returns a vector
+// that identifies, for each output index, whether its input can be moved
+// (assuming output indices are processed left to right).
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
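
To make the short-circuit contract above concrete (a standalone sketch, not TensorFlow code): for `f(x, y) = (y, x)` the output-to-input mapping is {1, 0}, for `f(x) = (x, x)` it is {0, 0}, and for `f(x) = x + 1` the indices are empty because the output is not one of the inputs. The move vector then marks only the last use of each input as movable:

#include <iostream>
#include <map>
#include <vector>

// Re-derivation of the move-vector rule for illustration only: an input may
// be moved at its last use among the output indices, scanning left to right.
std::vector<bool> ComputeMoveVectorSketch(const std::vector<int>& indices) {
  std::map<int, size_t> last_use;
  for (size_t i = 0; i < indices.size(); ++i) last_use[indices[i]] = i;
  std::vector<bool> can_move(indices.size());
  for (size_t i = 0; i < indices.size(); ++i)
    can_move[i] = (last_use[indices[i]] == i);
  return can_move;
}

int main() {
  // f(x0, x1) = (x1, x1, x0) gives indices = {1, 1, 0}.
  for (bool b : ComputeMoveVectorSketch({1, 1, 0}))
    std::cout << (b ? "move " : "copy ");
  std::cout << "\n";  // prints: copy move move
  return 0;
}
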
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
new file mode 100644
index 0000000000..43295b8ebb
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(DatasetUtils, ComputeMoveVector) {
+ struct TestCase {
+ std::vector<int> indices;
+ std::vector<bool> expected;
+ };
+
+ TestCase test_cases[] = {
+ TestCase{{}, {}},
+ TestCase{{1}, {true}},
+ TestCase{{1, 1}, {false, true}},
+ TestCase{{1, 2}, {true, true}},
+ TestCase{{1, 1, 2}, {false, true, true}},
+ TestCase{{1, 2, 2}, {true, false, true}},
+ };
+
+ for (auto& test_case : test_cases) {
+ EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
+ }
+}
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 00884314a9..be7d182a1f 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -31,67 +33,84 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
+ using FilterIteratorPredicate =
+ std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
+
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- FunctionLibraryRuntime::Handle pred_handle;
- OP_REQUIRES_OK(ctx,
- ctx->function_library()->Instantiate(
- func_.name(), AttrSlice(&func_.attr()), &pred_handle));
- auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
- OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
- });
-
- const FunctionBody* pred_body =
- ctx->function_library()->GetFunctionBody(pred_handle);
- OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
- errors::InvalidArgument(
- "predicate function must have a single return value."));
- Node* ret_node = pred_body->ret_nodes[0];
- Node* ret_input_node;
- OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
-
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- if (ret_input_node->def().op() == "_Arg") {
- int32 index = -1;
- OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
- *output = new FilterTensorDataset(ctx, input, func_,
- std::move(captured_func), index);
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+ OP_REQUIRES(ctx, indices.size() <= 1,
+ errors::InvalidArgument(
+ "predicate function has more than one return value."));
+
+ FilterIteratorPredicate filter_pred;
+ if (indices.empty()) {
+ CapturedFunction* raw_captured_func = captured_func.get();
+ filter_pred = [raw_captured_func](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ };
} else {
- *output = new FilterFunctionDataset(ctx, input, func_,
- std::move(captured_func));
+ filter_pred = [indices](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ const Tensor& predicate = args[indices[0]];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ };
}
+
+ *output = new Dataset(ctx, input, func_, std::move(captured_func),
+ std::move(filter_pred));
}
private:
- const int graph_def_version_;
-
- class FilterDatasetBase : public DatasetBase {
+ class Dataset : public DatasetBase {
public:
- FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ FilterIteratorPredicate filter_pred)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ filter_pred_(std::move(filter_pred)) {
input_->Ref();
}
- ~FilterDatasetBase() override { input_->Unref(); }
+ ~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Filter")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
+ filter_pred_);
}
const DataTypeVector& output_dtypes() const override {
@@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- virtual Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const = 0;
-
private:
- class Iterator : public DatasetIterator<FilterDatasetBase> {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
+ explicit Iterator(const Params& params,
+ FilterIteratorPredicate filter_pred)
+ : DatasetIterator<Dataset>(params),
filtered_elements_(0),
- dropped_elements_(0) {
+ dropped_elements_(0),
+ filter_pred_(std::move(filter_pred)) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(
- dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
+ const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
-
- protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- };
-
- class FilterFunctionDataset : public FilterDatasetBase {
- public:
- using FilterDatasetBase::FilterDatasetBase;
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- captured_func_->RunWithBorrowedArgs(ctx, element, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- }
- };
-
- class FilterTensorDataset : public FilterDatasetBase {
- public:
- FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- int32 index)
- : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
- index_(index) {}
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- const Tensor& predicate = element[index_];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- }
-
- private:
- const int32 index_;
+ const FilterIteratorPredicate filter_pred_;
};
private:
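
The two predicate paths above differ only in whether the user function is actually invoked; a standalone toy (plain bools instead of Tensors) of the short-circuit branch, where the result is read straight out of the arguments:

#include <functional>
#include <iostream>
#include <vector>

int main() {
  // Short-circuit indices for a predicate of the form `f(x, y) = x`.
  const std::vector<int> indices = {0};
  std::function<bool(const std::vector<bool>&)> filter_pred =
      [indices](const std::vector<bool>& args) {
        return args[indices[0]];  // no function call; forward the input
      };
  std::cout << filter_pred({true, false}) << "\n";  // 1
  std::cout << filter_pred({false, true}) << "\n";  // 0
  return 0;
}
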
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 7a833668ac..8acd6cc724 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -16,10 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
-#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
@@ -27,13 +25,11 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index bf08970560..0fb721cd7c 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -41,6 +43,10 @@ namespace {
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapAndBatchIteratorFunction =
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
+ std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
+
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -91,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_, func_,
- std::move(captured_func), &ctx->eigen_cpu_device());
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapAndBatchIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
+ std::move(done), prefix);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
+ *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_,
+ std::move(captured_func), &ctx->eigen_cpu_device(),
+ std::move(map_func));
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
- const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device)
+ const Eigen::ThreadPoolDevice* device,
+ MapAndBatchIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
- map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device) {
+ device_(device),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -123,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
+ map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -143,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -165,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(map_fn_, &f);
+ b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -185,12 +234,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
+ explicit Iterator(const Params& params,
+ MapAndBatchIteratorFunction map_func)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
- params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
+ map_func_(std::move(map_func)) {}
~Iterator() override {
mutex_lock l(*mu_);
@@ -297,44 +348,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
- void Callback(const std::shared_ptr<IteratorContext>& ctx,
- const std::shared_ptr<BatchResult>& result,
- const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does not "
- "match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- }
-
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
@@ -363,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- // Call `captured_func_(input_element)`, using `Callback` to store the
- // result in `result`.
- (*ctx->runner())(std::bind(
- [this, result, offset](std::shared_ptr<IteratorContext> ctx,
- std::vector<Tensor> input_element) {
- std::shared_ptr<std::vector<Tensor>> return_values(
- new std::vector<Tensor>());
- dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values.get(),
- [this, ctx, result, return_values, offset](Status status) {
- Callback(ctx, result, return_values, offset, status);
- },
- prefix());
- },
- ctx, std::move(input_element)));
+ std::shared_ptr<std::vector<Tensor>> return_values =
+ std::make_shared<std::vector<Tensor>>();
+ auto done = [this, ctx, result, return_values, offset](Status status) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor
+ // batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ };
+
+ // Apply the map function to `input_element`, storing the result in
+ // `return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ std::move(return_values), std::move(done));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -404,10 +444,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}
@@ -509,8 +550,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.emplace_back(
- new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
@@ -527,7 +568,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -653,6 +695,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
+ const MapAndBatchIteratorFunction map_func_;
+
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
@@ -660,7 +704,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
@@ -671,9 +715,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
+ const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index f112e1dc43..6b6ffabf4f 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -28,6 +30,9 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapIteratorFunction = std::function<Status(
+ IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
+
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ return raw_captured_func->Run(ctx, std::move(args), out_tensors);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ return Status::OK();
+ };
+ }
+
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_);
+ output_types_, output_shapes_, std::move(map_func));
}
private:
@@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
+ const std::vector<PartialTensorShape>& output_shapes,
+ MapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes) {
+ output_shapes_(output_shapes),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Map")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ explicit Iterator(const Params& params, MapIteratorFunction map_func)
+ : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- Status s =
- dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = map_func_(ctx, args, out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
+ const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 9aa505f4f1..859df57962 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- optimize_thread_.reset(ctx->env()->StartThread(
- {}, "optimize_thread",
- [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ optimize_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread");
+ optimize_thread_->Schedule(
+ [this, new_ctx]() { OptimizeThread(new_ctx); });
}
return Status::OK();
}
@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex mu_;
condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
- std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 6b6b3d6ab9..9c836b836e 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(ctx->env()->StartThread(
- {}, "worker_thread",
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+ worker_threads_.emplace_back(
+ MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+ worker_threads_.back()->Schedule(
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
}
}
return Status::OK();
@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(ctx->env()->StartThread(
- {}, "worker_thread",
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+ worker_threads_.emplace_back(
+ MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+ worker_threads_.back()->Schedule(
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
- std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
+ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- [this, new_ctx]() { RunnerThread(new_ctx); }));
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ [this, new_ctx]() { RunnerThread(new_ctx); });
}
}
@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false;
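
The `worker_threads_`/`runner_thread_` members above are declared last so their background threads exit before other state is destroyed; a standalone analogue (not the TensorFlow BackgroundWorker, whose constructor takes an Env* and a name as used throughout this change) sketching that schedule-closures, join-on-destruction pattern:

#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

class SimpleBackgroundWorker {
 public:
  SimpleBackgroundWorker() : thread_([this] { Loop(); }) {}
  ~SimpleBackgroundWorker() {
    {
      std::lock_guard<std::mutex> l(mu_);
      cancelled_ = true;
    }
    cv_.notify_one();
    thread_.join();  // pending work is drained before destruction finishes
  }
  void Schedule(std::function<void()> work) {
    {
      std::lock_guard<std::mutex> l(mu_);
      queue_.push_back(std::move(work));
    }
    cv_.notify_one();
  }

 private:
  void Loop() {
    while (true) {
      std::function<void()> work;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return cancelled_ || !queue_.empty(); });
        if (queue_.empty()) return;  // cancelled and nothing left to run
        work = std::move(queue_.front());
        queue_.pop_front();
      }
      work();
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> queue_;
  bool cancelled_ = false;
  std::thread thread_;  // declared last so it starts after the state above
};

int main() {
  SimpleBackgroundWorker worker;
  worker.Schedule([] { std::cout << "background work\n"; });
  return 0;
}  // ~SimpleBackgroundWorker joins the thread here
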
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 6abe6c8338..3a14924fba 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ ParallelMapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
+ std::move(done), prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
+ out_tensors, std::move(done)));
+ };
+ }
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args, std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func));
+ std::move(captured_func), std::move(map_func));
}
private:
@@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func)
+ std::unique_ptr<CapturedFunction> captured_func,
+ ParallelMapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
- ParallelMapIteratorFunction map_func =
- [this, new_prefix](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done), new_prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](
- IteratorContext* ctx, std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
- result, std::move(done)));
- };
- }
-
- return NewParallelMapIterator({this, new_prefix}, input_,
- std::move(init_func), std::move(map_func),
- num_parallel_calls_);
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(init_func), map_func_, num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
+ const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 13bd4b6036..e69274e4f2 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -179,10 +180,11 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}
@@ -208,15 +210,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in `result->return_values`,
- // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- map_func_(ctx.get(), std::move(input_element), &result->return_values,
- std::move(done));
+    // Apply the map function to `input_element`, storing the result in

+ // `result->return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ &result->return_values, std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
@@ -349,9 +351,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(
- new ParallelMapIterator(params, input_dataset, std::move(init_func),
- std::move(map_func), num_parallel_calls));
+ return MakeUnique<ParallelMapIterator>(
+ params, input_dataset, std::move(init_func), std::move(map_func),
+ num_parallel_calls);
}
} // namespace data
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index dc26c5cf25..813f13c9e4 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -30,7 +30,7 @@ namespace data {
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
using ParallelMapIteratorFunction =
- std::function<void(IteratorContext*, std::vector<Tensor>,
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 1d1a717062..7de5ea8860 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_fn = [this](IteratorContext* ctx,
+ auto map_fn = [this](IteratorContext* ctx, const string& prefix,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 754ed772db..e9c38eb8a0 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
+ prefetch_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread");
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- prefetch_thread_.reset(ctx->env()->StartThread(
- {}, "prefetch_thread",
- [this, new_ctx]() { PrefetchThread(new_ctx); }));
+ prefetch_thread_->Schedule(
+ [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}
@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
};
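
Both this hunk and the parallel_map_iterator.cc change above swap a raw Thread started via Env::StartThread for a BackgroundWorker that is constructed eagerly and has work scheduled onto it later. The sketch below shows the general shape of such a single-threaded worker using only the Env, mutex and condition_variable primitives visible in these hunks; it is an illustration of the pattern with made-up names, not TensorFlow's actual BackgroundWorker.

#include <deque>
#include <functional>
#include <memory>
#include <string>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace example {

// Minimal single-threaded worker exposing the Schedule() surface used above.
class SimpleBackgroundWorker {
 public:
  SimpleBackgroundWorker(tensorflow::Env* env, const std::string& name) {
    // All members are initialized before the thread starts touching them.
    thread_.reset(env->StartThread({}, name, [this]() { WorkerLoop(); }));
  }

  ~SimpleBackgroundWorker() {
    {
      tensorflow::mutex_lock l(mu_);
      cancelled_ = true;
      cond_.notify_all();
    }
    thread_.reset();  // Thread's destructor joins; queued work is dropped.
  }

  void Schedule(std::function<void()> work) {
    tensorflow::mutex_lock l(mu_);
    queue_.push_back(std::move(work));
    cond_.notify_one();
  }

 private:
  void WorkerLoop() {
    while (true) {
      std::function<void()> work;
      {
        tensorflow::mutex_lock l(mu_);
        while (!cancelled_ && queue_.empty()) {
          cond_.wait(l);
        }
        if (cancelled_) return;
        work = std::move(queue_.front());
        queue_.pop_front();
      }
      work();  // Run outside the lock so Schedule() never blocks on work.
    }
  }

  tensorflow::mutex mu_;
  tensorflow::condition_variable cond_;
  std::unique_ptr<tensorflow::Thread> thread_;
  std::deque<std::function<void()>> queue_ GUARDED_BY(mu_);
  bool cancelled_ GUARDED_BY(mu_) = false;
};

}  // namespace example
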
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 66466d6a36..9f54c381a9 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -485,7 +485,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
int64 buffer_size, int64 seed, int64 seed2, int64 count)
: ShuffleDatasetBase(ctx, input, buffer_size, count),
seed_(seed),
- seed2_(seed) {}
+ seed2_(seed2) {}
string DebugString() const override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 3f76695bb1..7bb2077b62 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public:
explicit ToTFRecordOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) {
+ }
template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
     // owned background worker.
- thread_pool_->Schedule([this, ctx, done]() {
+ background_worker_.Schedule([this, ctx, done]() {
string filename;
OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done);
@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
}
private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
+ BackgroundWorker background_worker_;
};
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 173fea37ed..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_RELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluGradOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6Op<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6GradOp<CPUDevice, type>)
+#define REGISTER_RELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
@@ -99,6 +105,19 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
+ void LeakyRelu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct LeakyRelu<GPUDevice, T>; \
+ \
+ template <> \
+ void LeakyReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct LeakyReluGrad<GPUDevice, T>; \
+ \
+ template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
@@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6Op<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6GradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- SeluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
@@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
 // Registration of the SYCL implementations.
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6Op<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6GradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- SeluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
}
};
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+ // Computes LeakyRelu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ T alpha, typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(features * alpha);
+ }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+ // Computes LeakyReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the LeakyRelu op.
+  // features: either the inputs that were passed to the LeakyRelu op, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the LeakyRelu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features, T alpha,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+ }
+};
+
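
Element-wise, the two functors above reduce to simple arithmetic: the forward pass is max(x, alpha * x), which equals x for positive inputs and alpha * x otherwise whenever 0 <= alpha <= 1, and the backward pass forwards the incoming gradient where the saved features are positive and scales it by alpha elsewhere. A standalone reference version of that arithmetic, written purely as an illustration of what the Eigen expressions evaluate to:

#include <algorithm>
#include <cstddef>
#include <vector>

// Reference forward pass: y[i] = max(x[i], alpha * x[i]); assumes alpha <= 1,
// as the cwiseMax formulation above does.
std::vector<float> LeakyReluRef(const std::vector<float>& x, float alpha) {
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = std::max(x[i], alpha * x[i]);
  }
  return y;
}

// Reference backward pass, matching
// (features > 0).select(gradients, gradients * alpha).
std::vector<float> LeakyReluGradRef(const std::vector<float>& dy,
                                    const std::vector<float>& x, float alpha) {
  std::vector<float> dx(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    dx[i] = x[i] > 0.0f ? dy[i] : alpha * dy[i];
  }
  return dx;
}
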
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index b9391517c1..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -145,14 +145,16 @@ struct Relu<Device, qint8> {
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Relu<GPUDevice, T>; \
- template struct functor::ReluGrad<GPUDevice, T>; \
- template struct functor::Relu6<GPUDevice, T>; \
- template struct functor::Relu6Grad<GPUDevice, T>; \
- template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>; \
- template struct functor::Selu<GPUDevice, T>; \
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>; \
+ template struct functor::LeakyRelu<GPUDevice, T>; \
+ template struct functor::LeakyReluGrad<GPUDevice, T>; \
+ template struct functor::Elu<GPUDevice, T>; \
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- CPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+ using StatelessRandomOpBase::StatelessRandomOpBase;
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+ void Fill(OpKernelContext* context, random::PhiloxRandom random,
+ Tensor* output) override {
+ const Tensor& minval = context->input(2);
+ const Tensor& maxval = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().DebugString()));
+
+    // Verify that minval < maxval. Note that we'll never reach this point for
+    // empty output, so requesting zero samples from an impossible range is
+    // not an error.
+ const auto lo = minval.scalar<IntType>()();
+ const auto hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ context, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ auto flat = output->flat<IntType>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ context, context->eigen_device<Device>(), random, flat.data(),
+ flat.size(), dist);
+ }
+};
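
Fill() reuses the same lib/random primitives as the stateful kernels, so for reference the sketch below draws one batch of uniform integers in [lo, hi) with them directly. It is illustrative only; the function name is made up, and the batch size comes from the distribution's kResultElementCount rather than being chosen by the caller.

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/types.h"

namespace example {

// Illustrative only: one batch of uniform int32 samples in [lo, hi) from the
// primitives handed to FillPhiloxRandom by StatelessRandomUniformIntOp::Fill.
tensorflow::int64 SumOneUniformIntBatch(tensorflow::int32 lo,
                                        tensorflow::int32 hi,
                                        tensorflow::uint64 seed_lo,
                                        tensorflow::uint64 seed_hi) {
  tensorflow::random::PhiloxRandom gen(seed_lo, seed_hi);
  tensorflow::random::UniformDistribution<tensorflow::random::PhiloxRandom,
                                          tensorflow::int32>
      dist(lo, hi);
  auto batch = dist(&gen);  // A small, fixed-size array of samples.
  tensorflow::int64 sum = 0;
  for (int i = 0; i < dist.kResultElementCount; ++i) {
    sum += batch[i];
  }
  return sum;
}

}  // namespace example
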
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomUniform") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessTruncatedNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp< \
+ DEVICE##Device, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- GPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif // GOOGLE_CUDA
#undef REGISTER
-
-#endif // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
} // namespace
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 3559baa18e..3bdcfc90b8 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -108,7 +108,7 @@ class UniqueOp : public OpKernel {
std::unordered_map<T, TIndex> uniq;
uniq.reserve(2 * N);
- for (int64 i = 0, j = 0; i < N; ++i) {
+ for (Eigen::Index i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
if (it.second) {
@@ -131,19 +131,20 @@ class UniqueOp : public OpKernel {
// General implementation when unique is run over multiple elements.
auto Tin = input.shaped<T, 3>(new_sizes);
- auto hash_fn = [&Tin](const int64& key) {
+ auto hash_fn = [&Tin](const Eigen::Index& key) {
size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
}
}
return h;
};
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ auto equal_to_fn = [&Tin](const Eigen::Index& lhs,
+ const Eigen::Index& rhs) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
return false;
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 780c6f6448..9df0ece69b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28981,6 +28981,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -70897,6 +70965,62 @@ op {
}
}
op {
+ name: "StatelessRandomNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessRandomUniform"
input_arg {
name: "shape"
@@ -70994,6 +71118,118 @@ op {
}
}
op {
+ name: "StatelessRandomUniform"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -71091,6 +71327,62 @@ op {
}
}
op {
+ name: "StatelessTruncatedNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessWhile"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d1d81b27cc..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+REGISTER_OP("LeakyRelu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("LeakyReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0d8997c1bd..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14296,6 +14296,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -32978,6 +33046,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33033,6 +33102,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33066,6 +33136,62 @@ op {
}
}
op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -33088,6 +33214,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8c31468ff5..7ccd54b818 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -83,6 +83,10 @@ message RewriterConfig {
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
NumIterationsType meta_optimizer_iterations = 12;
+ // Maximum number of milliseconds to spend optimizing a single graph before
+ // timing out. If equal to 0 the system picks a default (currently 5 minutes).
+ // If less than 0 the optimizer will never time out.
+ int64 meta_optimizer_timeout_ms = 20;
   // The minimum number of nodes in a graph to optimize. For smaller graphs,
// optimization is skipped.
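
The new field is set like any other scalar proto field; a minimal sketch, assuming only the setter the proto compiler generates for meta_optimizer_timeout_ms:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Illustrative only: cap the meta optimizer at 60 seconds per graph. Per the
// field comment above, 0 keeps the built-in default and a negative value
// disables the timeout entirely.
tensorflow::RewriterConfig MakeRewriterConfigWithTimeout() {
  tensorflow::RewriterConfig config;
  config.set_meta_optimizer_timeout_ms(60000);
  return config;
}
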