diff options
author | 2017-11-29 09:48:08 -0800 | |
---|---|---|
committer | 2017-11-29 09:51:53 -0800 | |
commit | 537ecc56cf09d5dcb2b328b322d9f8b195abcc6c (patch) | |
tree | 0256eec0ebe97d6c88720027beb579c5c40fd91b /tensorflow/core/kernels/dataset.h | |
parent | 667282eb0e62bef03bbe527bef88c656532444bb (diff) |
[tf.data] Remove GraphDefBuilder and NodeBuilder dependencies from "dataset.h".
This is a step towards making a header-only library on which external op
implementations can depend. To do this "dataset.h" cannot depend on any
headers in "tensorflow/core/graph/...".
PiperOrigin-RevId: 177322011
Diffstat (limited to 'tensorflow/core/kernels/dataset.h')
-rw-r--r-- | tensorflow/core/kernels/dataset.h | 155 |
1 files changed, 21 insertions, 134 deletions
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index afbebb0692..504a88a309 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -19,12 +19,13 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" -#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tracing.h" @@ -59,6 +60,12 @@ class IteratorStateWriter { virtual ~IteratorStateWriter() {} }; +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class GraphDatasetBase; +class Node; + // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. class GraphDefBuilderWrapper { public: @@ -110,10 +117,8 @@ class GraphDefBuilderWrapper { return Status::OK(); } - template <class DatasetType> - Status AddDataset(const DatasetType* dataset, - const std::vector<NodeBuilder::NodeOut>& inputs, - Node** output) { + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector<Node*>& inputs, Node** output) { return AddDataset(dataset, inputs, {}, output); } @@ -125,77 +130,23 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - template <class DatasetType> - Status AddDataset(const DatasetType* dataset, - const std::vector<NodeBuilder::NodeOut>& inputs, + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector<Node*>& inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, Node** output) { - std::vector<std::pair<size_t, NodeBuilder::NodeOut>> enumerated_inputs( - inputs.size()); + std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size()); for (int i = 0; i < inputs.size(); i++) { enumerated_inputs[i] = std::make_pair(i, inputs[i]); } return AddDataset(dataset, enumerated_inputs, {}, attrs, output); } - template <class DatasetType> Status AddDataset( - const DatasetType* dataset, - const std::vector<std::pair<size_t, NodeBuilder::NodeOut>>& inputs, - const std::vector< - std::pair<size_t, gtl::ArraySlice<NodeBuilder::NodeOut>>>& - list_inputs, + const GraphDatasetBase* dataset, + const std::vector<std::pair<size_t, Node*>>& inputs, + const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, - Node** output) { - const string& op_type_name = dataset->op_name(); - std::unique_ptr<const GraphDefBuilder::Options> opts( - new GraphDefBuilder::Options(b_->opts())); - // TODO(srbs|mrry): Not all datasets have output_types and output_shapes - // attributes defined. It will be nice to have a consistent pattern. - bool has_output_types_attr = HasAttr(op_type_name, "output_types"); - bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); - if (has_output_shapes_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_shapes", dataset->output_shapes()))); - } - if (has_output_types_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_types", dataset->output_dtypes()))); - } - for (auto attr : attrs) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr(attr.first, attr.second))); - } - if (opts->HaveError()) { - return errors::Internal("AddDataset: Failed to build Options with error ", - opts->StatusToString()); - } - NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, - opts->op_registry()); - { - size_t total_size = inputs.size() + list_inputs.size(); - auto inputs_iter = inputs.begin(); - auto list_inputs_iter = list_inputs.begin(); - for (int i = 0; i < total_size; i++) { - if (inputs_iter != inputs.end() && inputs_iter->first == i) { - node_builder.Input(inputs_iter->second); - inputs_iter++; - } else if (list_inputs_iter != list_inputs.end() && - list_inputs_iter->first == i) { - node_builder.Input(list_inputs_iter->second); - list_inputs_iter++; - } else { - return errors::InvalidArgument("No input found for index ", i); - } - } - } - *output = opts->FinalizeBuilder(&node_builder); - if (*output == nullptr) { - return errors::Internal("AddDataset: Failed to build ", op_type_name, - " op with error ", opts->StatusToString()); - } - return Status::OK(); - } + Node** output); // Adds a user-defined function with name `function_name` to the graph and // recursively adds all functions it references. If a function with a matching @@ -203,50 +154,7 @@ class GraphDefBuilderWrapper { // name `function_name` is not found in the FunctionLibraryDefinition, returns // an InvalidArgumentError. If the function with name `function_name` or any // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(OpKernelContext* ctx, const string& function_name) { - if (b_->HasFunction(function_name)) { - LOG(INFO) << "Function with name " << function_name << "already exists in" - << " the graph. It will not be added again."; - return Status::OK(); - } - TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); - const FunctionLibraryDefinition* flib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* f_def = flib_def->Find(function_name); - if (f_def == nullptr) { - return errors::InvalidArgument("Unable to find FunctionDef for ", - function_name, " in the registry."); - } - FunctionDefLibrary def; - *def.add_function() = *f_def; - const string gradient_func = flib_def->FindGradient(function_name); - if (!gradient_func.empty()) { - GradientDef* g_def = def.add_gradient(); - g_def->set_function_name(function_name); - g_def->set_gradient_func(gradient_func); - } - TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); - - // Recursively add functions in inputs of function_name. - for (const NodeDef& node_def : f_def->node_def()) { - const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); - if (op_reg_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); - } - // Recursively add functions in attrs of this NodeDef. - for (const auto& pair : node_def.attr()) { - TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx)); - } - } - - // Recursively add functions in attrs of function_name. - for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); - iter++) { - TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx)); - } - return Status::OK(); - } + Status AddFunction(OpKernelContext* ctx, const string& function_name); template <typename T> void BuildAttrValue(const T& value, AttrValue* attr) { @@ -254,11 +162,7 @@ class GraphDefBuilderWrapper { } private: - void AddTensorInternal(const Tensor& val, Node** output) { - *output = ops::SourceOp( - "Const", - b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); - } + void AddTensorInternal(const Tensor& val, Node** output); Status EnsureFunctionIsStateless(OpKernelContext* ctx, const string& function_name) const { @@ -294,14 +198,7 @@ class GraphDefBuilderWrapper { HasAttr(op_def, "output_shapes"); } - bool HasAttr(const string& op_type_name, const string& attr_name) const { - const OpDef* op_def = nullptr; - Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); - if (!s.ok() || op_def == nullptr) { - return false; - } - return HasAttr(op_def, attr_name); - } + bool HasAttr(const string& op_type_name, const string& attr_name) const; bool HasAttr(const OpDef* op_def, const string& attr_name) const { for (auto attr : op_def->attr()) { @@ -548,17 +445,7 @@ class GraphDatasetBase : public DatasetBase { private: Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, - string* output_node) const { - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); - *output_node = node->name(); - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - graph_def.SerializeToString(serialized_graph_def); - return Status::OK(); - } + string* output_node) const; const string op_name_; }; |