about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Peter Hawkins <phawkins@google.com> 2018-06-13 13:05:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-13 13:08:14 -0700
commit 642a043de4901ddbf305db105168b8908adfe99e (patch)
tree 30c5e1be065efaf9e0ec0fbe660ab0978b4736a4
parent 8051c4b7790bb3cc64bf14d1180ab2ad55f0c032 (diff)
[TF:XLA] Replace bespoke NodeSlot class in subgraph encapsulation code with InputTensor and OutputTensor classes from TF core.
Add equality and hash methods to InputTensor and OutputTensor. No functional changes intended. PiperOrigin-RevId: 200440015
-rw-r--r-- tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 127
-rw-r--r-- tensorflow/core/graph/graph.cc | 23
-rw-r--r-- tensorflow/core/graph/graph.h | 20
3 files changed, 97 insertions(+), 73 deletions(-)
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ea90d714c8..edd2247694 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -106,41 +106,11 @@ void MarkGuaranteedConstants(
}
}
-// A node/slot pair.
-// TODO(phawkins): is there a common definition of this?
-struct NodeSlot {
- NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot)
- : node(node), slot(slot), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot, DataType dtype)
- : node(node), slot(slot), dtype(dtype) {}
-
- const Node* node;
- int slot;
-
- // Optional: used to record the destination type of a source NodeSlot in case
- // the source output is a Ref type that is cast to a Tensor at the
- // destination.
- DataType dtype;
-
- bool operator==(const NodeSlot& other) const {
- return node == other.node && slot == other.slot && dtype == other.dtype;
- }
-
- // Leave dtype out of the hash since there are never two NodeSlots with the
- // same node and slot and different dtypes.
- struct Hasher {
- uint64 operator()(NodeSlot const& s) const {
- return Hash64Combine(std::hash<const Node*>()(s.node),
- std::hash<int>()(s.slot));
- }
- };
-
- struct PairHasher {
- uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
- return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
- }
- };
+struct OutputInputTensorPairHasher {
+ uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
+ return Hash64Combine(OutputTensor::Hash()(s.first),
+ InputTensor::Hash()(s.second));
+ }
};
// TODO(phawkins) add a canonical copy of these operator names and refactor
@@ -376,7 +346,7 @@ class Encapsulator {
// Map from source (producer node/slot) tensors in the original graph to
// input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> inputs;
// Set of nodes in the original graph that are the source of control edges
// that cross from the containing compiled subgraph into the
@@ -392,8 +362,15 @@ class Encapsulator {
// node/slot) tensors in the original graph to output index (slot number
// in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
+ struct ArgNumAndType {
+ int index;
+ DataType dtype;
+
+ ArgNumAndType(int i, DataType t) : index(i), dtype(t) {}
+ };
+ std::unordered_map<OutputTensor, ArgNumAndType, OutputTensor::Hash>
+ outputs_by_src;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> outputs_by_dst;
// Set of nodes in the original graph that are the destination of control
// edges that cross from the outside_compilation subgraph into the
@@ -479,14 +456,14 @@ class Encapsulator {
// (consumer node/slot) tensors in the input graph to _Arg numbers in
// the subgraph. The source map is one-to-one, whereas the dest map may be
// many-to-one.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
// The _Arg nodes in the subgraph, in order by argument number.
std::vector<Node*> args_;
// Map from source tensor in the input graph to result #.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
// The outside_compilation clusters in this subgraph.
std::unordered_map<string, OutsideCompilationSubgraph>
@@ -583,8 +560,8 @@ class Encapsulator {
const string& dst_outside_compilation_id,
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added);
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added);
// Adds control dependencies between subgraph call nodes that have
// dependencies via outside_compilation edges.
@@ -716,11 +693,11 @@ void TopologicalClusterSort(
Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
- return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
+ return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
}
int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
- return results_.at(NodeSlot(edge->src(), edge->src_output()));
+ return results_.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetRecvAtHostNode(
@@ -732,7 +709,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode(
int Encapsulator::Subgraph::GetRecvAtHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .inputs.at(NodeSlot(edge->src(), edge->src_output()));
+ .inputs.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetSendFromHostNode(
@@ -744,7 +721,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode(
int Encapsulator::Subgraph::GetSendFromHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
+ .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input()));
}
Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
@@ -769,10 +746,10 @@ Status Encapsulator::Subgraph::RecordArg(
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
Node* src_node = edge->src();
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
- std::tie(iter, inserted) =
- args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
+ std::tie(iter, inserted) = args_by_src_.emplace(
+ OutputTensor(src_node, src_slot), args_by_src_.size());
int arg_index = iter->second;
if (inserted) {
NodeDef arg_def;
@@ -793,7 +770,7 @@ Status Encapsulator::Subgraph::RecordArg(
Node* dst_node = edge->dst();
Node* dst_image = node_images.at(dst_node);
int dst_slot = edge->dst_input();
- args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
+ args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
return Status::OK();
}
@@ -804,10 +781,10 @@ Status Encapsulator::Subgraph::RecordResult(
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
std::tie(iter, inserted) =
- results_.emplace(NodeSlot(src_node, src_slot), results_.size());
+ results_.emplace(OutputTensor(src_node, src_slot), results_.size());
int ret_index = iter->second;
if (inserted) {
NodeDef ret_def;
@@ -845,8 +822,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
outside_subgraph->control_inputs.insert(edge->src());
} else {
int input_index = outside_subgraph->inputs.size();
- outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
- input_index);
+ outside_subgraph->inputs.emplace(
+ OutputTensor(edge->src(), edge->src_output()), input_index);
}
}
@@ -860,11 +837,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
DataType dtype = edge->dst()->input_type(edge->dst_input());
auto output_iter =
outside_subgraph->outputs_by_src
- .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
- outside_subgraph->outputs_by_src.size())
+ .emplace(OutputTensor(edge->src(), edge->src_output()),
+ OutsideCompilationSubgraph::ArgNumAndType(
+ outside_subgraph->outputs_by_src.size(), dtype))
.first;
- int output_index = output_iter->second;
- outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
+ const int output_index = output_iter->second.index;
+ outside_subgraph
+ ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] =
output_index;
}
}
@@ -946,7 +925,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
DataType dtype = src_node->output_type(src_slot);
@@ -954,8 +933,8 @@ Status Encapsulator::Subgraph::AddHostComputes(
input_dtypes[input_index] = dtype;
}
for (const auto& output : oc_subgraph.outputs_by_src) {
- DataType dtype = output.first.dtype;
- int output_index = output.second;
+ DataType dtype = output.second.dtype;
+ int output_index = output.second.index;
output_dtypes[output_index] = dtype;
}
@@ -993,7 +972,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}
@@ -1015,7 +994,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
- int dst_slot = output.first.slot;
+ int dst_slot = output.first.index;
int output_index = output.second;
graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
@@ -1226,7 +1205,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
for (const auto& input : oc_subgraph->inputs) {
const Node* src_node = input.first.node;
- int src_slot = input.first.slot;
+ int src_slot = input.first.index;
int input_index = input.second;
DataType dtype = src_node->output_type(src_slot);
@@ -1280,8 +1259,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
for (const auto& output : oc_subgraph->outputs_by_src) {
const Node* src_node = output.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = output.first.slot;
- int output_index = output.second;
+ int src_slot = output.first.index;
+ int output_index = output.second.index;
DataType dtype = src_node->output_type(src_slot);
dtypes[output_index] = dtype;
@@ -1680,8 +1659,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
const string& src_outside_compilation_id, const string& dst_func_id,
const string& dst_outside_compilation_id,
const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added) {
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added) {
Node* src_image;
TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
src_func_id, src_outside_compilation_id, dst_func_id,
@@ -1696,7 +1675,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
if (edge->IsControlEdge()) {
// Add the control edge, if we have not already added it, using the images
// determined above (potentially call operators or RecvAtHost/SendFromHost).
- if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
+ if (edges_added
+ ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
.second) {
graph_out->AddControlEdge(src_image, dst_image);
}
@@ -1714,8 +1694,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
// Add the edge, if we have not already added it.
if (edges_added
- ->emplace(NodeSlot(src_image, src_output),
- NodeSlot(dst_image, dst_input))
+ ->emplace(OutputTensor(src_image, src_output),
+ InputTensor(dst_image, dst_input))
.second) {
graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
}
@@ -1739,7 +1719,8 @@ Status Encapsulator::AddEdgesToOutputGraph(
// Set of edges already added to the output graph, represented as (src, dst)
// pairs. We use the set to deduplicate edges; multiple edges in the input
// graph may map to one edge in the output graph.
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>
edges_added;
for (const Edge* edge : graph_in_->edges()) {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 0f748515ef..568f0870c0 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -265,6 +266,28 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
+// InputTensor
+
+bool InputTensor::operator==(const InputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
+// OutputTensor
+
+bool OutputTensor::operator==(const OutputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
// Graph
Graph::Graph(const OpRegistryInterface* ops)
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 33fb7cb57a..a147c94689 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -284,6 +284,16 @@ struct InputTensor {
InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this InputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const InputTensor& other) const;
+
+ // A hash function for InputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(InputTensor const& s) const;
+ };
};
// Represents an output of a node, i.e., the `index`-th output of `node`. Note
@@ -295,6 +305,16 @@ struct OutputTensor {
OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this OutputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const OutputTensor& other) const;
+
+ // A hash function for OutputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(OutputTensor const& s) const;
+ };
};
class Edge {