author      Rachel Lim <rachelim@google.com>  2018-09-28 16:10:45 -0700
committer   TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 16:17:55 -0700
commit      478d370eb116ad2294134d75a886637a7d6da225 (patch)
tree        279ef8e8a2c9abeeda583393a986f055b9be314c /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
parent      a98bac521406bedef3ff2b9af9564b21ddda4d82 (diff)
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 451
1 file changed, 273 insertions(+), 178 deletions(-)
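
At a glance, the caller-facing change is the VectorizeMapDefun entry point: it now takes the outer function and the MapDefun node as const inputs, writes any newly created functions into a FunctionDefLibrary, and reports failure through a Status instead of mutating FunctionDefs in place. The old and new declarations, as they appear in the final hunk below:

// Old entry point: mutated `outer_scope` / `map_defun_fn` in place, with no
// error reporting.
void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
                       NodeDef* map_defun_node);

// New entry point: inputs are const; new functions are added to `lib`, and
// `*result` points at the vectorized outer function (nullptr on failure).
Status VectorizeMapDefun(const FunctionDef& outer_scope,
                         const NodeDef& map_defun_node, FunctionDefLibrary* lib,
                         FunctionDef** result);
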
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..cea667f668 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -36,255 +40,346 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to lib, and points
+ // `result` to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's output tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+ // 2) Promote the inputs of the op to outputs of the
+ // `map_defun_node_` and `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Maps a tensor to the corresponding vectorized tensor. For example,
+ // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
+ std::map<TensorDesc, TensorDesc> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote the function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, which should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
-
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
-
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
- }
-
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
if (vectorizer == nullptr) {
return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
}
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ converted_output = *found;
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output);
}
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
} else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
- *converted = conversion_map_.at(output_desc.full_str);
- return Status::OK();
}
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBodys so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ // Add mappings from map_defun_fn_'s arg nodes to map_defun_node_'s input
+ // nodes to the conversion map.
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
- }
- }
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
- }
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
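
For context, a caller now threads a FunctionDefLibrary through the call and checks the returned Status rather than editing the functions in place. A minimal sketch, assuming TensorFlow's grappler namespaces; the wrapper name MaybeVectorizeMapDefun and its surrounding plumbing are hypothetical and not part of this patch:

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper, for illustration only: tries to vectorize the given
// MapDefun node and leaves `*out` as nullptr if the conversion fails.
Status MaybeVectorizeMapDefun(const FunctionDef& outer_fn,
                              const NodeDef& map_defun_node,
                              FunctionDefLibrary* library, FunctionDef** out) {
  Status s = vectorization_utils::VectorizeMapDefun(outer_fn, map_defun_node,
                                                    library, out);
  if (!s.ok()) {
    // Keep the original, unvectorized function if anything went wrong.
    VLOG(2) << "Could not vectorize MapDefun node " << map_defun_node.name()
            << ": " << s;
  }
  return s;
}

}  // namespace grappler
}  // namespace tensorflow
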