aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dataset.h
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-11-29 09:48:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-29 09:51:53 -0800
commit537ecc56cf09d5dcb2b328b322d9f8b195abcc6c (patch)
tree0256eec0ebe97d6c88720027beb579c5c40fd91b /tensorflow/core/kernels/dataset.h
parent667282eb0e62bef03bbe527bef88c656532444bb (diff)
[tf.data] Remove GraphDefBuilder and NodeBuilder dependencies from "dataset.h".
This is a step towards making a header-only library on which external op implementations can depend. To do this "dataset.h" cannot depend on any headers in "tensorflow/core/graph/...". PiperOrigin-RevId: 177322011
Diffstat (limited to 'tensorflow/core/kernels/dataset.h')
-rw-r--r--tensorflow/core/kernels/dataset.h155
1 files changed, 21 insertions, 134 deletions
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index afbebb0692..504a88a309 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -19,12 +19,13 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
@@ -59,6 +60,12 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class GraphDatasetBase;
+class Node;
+
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -110,10 +117,8 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
- Node** output) {
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -125,77 +130,23 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
- std::vector<std::pair<size_t, NodeBuilder::NodeOut>> enumerated_inputs(
- inputs.size());
+ std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
enumerated_inputs[i] = std::make_pair(i, inputs[i]);
}
return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}
- template <class DatasetType>
Status AddDataset(
- const DatasetType* dataset,
- const std::vector<std::pair<size_t, NodeBuilder::NodeOut>>& inputs,
- const std::vector<
- std::pair<size_t, gtl::ArraySlice<NodeBuilder::NodeOut>>>&
- list_inputs,
+ const GraphDatasetBase* dataset,
+ const std::vector<std::pair<size_t, Node*>>& inputs,
+ const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
- Node** output) {
- const string& op_type_name = dataset->op_name();
- std::unique_ptr<const GraphDefBuilder::Options> opts(
- new GraphDefBuilder::Options(b_->opts()));
- // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
- // attributes defined. It will be nice to have a consistent pattern.
- bool has_output_types_attr = HasAttr(op_type_name, "output_types");
- bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
- if (has_output_shapes_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_shapes", dataset->output_shapes())));
- }
- if (has_output_types_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_types", dataset->output_dtypes())));
- }
- for (auto attr : attrs) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr(attr.first, attr.second)));
- }
- if (opts->HaveError()) {
- return errors::Internal("AddDataset: Failed to build Options with error ",
- opts->StatusToString());
- }
- NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
- opts->op_registry());
- {
- size_t total_size = inputs.size() + list_inputs.size();
- auto inputs_iter = inputs.begin();
- auto list_inputs_iter = list_inputs.begin();
- for (int i = 0; i < total_size; i++) {
- if (inputs_iter != inputs.end() && inputs_iter->first == i) {
- node_builder.Input(inputs_iter->second);
- inputs_iter++;
- } else if (list_inputs_iter != list_inputs.end() &&
- list_inputs_iter->first == i) {
- node_builder.Input(list_inputs_iter->second);
- list_inputs_iter++;
- } else {
- return errors::InvalidArgument("No input found for index ", i);
- }
- }
- }
- *output = opts->FinalizeBuilder(&node_builder);
- if (*output == nullptr) {
- return errors::Internal("AddDataset: Failed to build ", op_type_name,
- " op with error ", opts->StatusToString());
- }
- return Status::OK();
- }
+ Node** output);
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
@@ -203,50 +154,7 @@ class GraphDefBuilderWrapper {
// name `function_name` is not found in the FunctionLibraryDefinition, returns
// an InvalidArgumentError. If the function with name `function_name` or any
// of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(OpKernelContext* ctx, const string& function_name) {
- if (b_->HasFunction(function_name)) {
- LOG(INFO) << "Function with name " << function_name << "already exists in"
- << " the graph. It will not be added again.";
- return Status::OK();
- }
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
- const FunctionLibraryDefinition* flib_def =
- ctx->function_library()->GetFunctionLibraryDefinition();
- const FunctionDef* f_def = flib_def->Find(function_name);
- if (f_def == nullptr) {
- return errors::InvalidArgument("Unable to find FunctionDef for ",
- function_name, " in the registry.");
- }
- FunctionDefLibrary def;
- *def.add_function() = *f_def;
- const string gradient_func = flib_def->FindGradient(function_name);
- if (!gradient_func.empty()) {
- GradientDef* g_def = def.add_gradient();
- g_def->set_function_name(function_name);
- g_def->set_gradient_func(gradient_func);
- }
- TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
-
- // Recursively add functions in inputs of function_name.
- for (const NodeDef& node_def : f_def->node_def()) {
- const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
- if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
- }
- // Recursively add functions in attrs of this NodeDef.
- for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
- }
- }
-
- // Recursively add functions in attrs of function_name.
- for (auto iter = f_def->attr().begin(); iter != f_def->attr().end();
- iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
- }
- return Status::OK();
- }
+ Status AddFunction(OpKernelContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -254,11 +162,7 @@ class GraphDefBuilderWrapper {
}
private:
- void AddTensorInternal(const Tensor& val, Node** output) {
- *output = ops::SourceOp(
- "Const",
- b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
- }
+ void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(OpKernelContext* ctx,
const string& function_name) const {
@@ -294,14 +198,7 @@ class GraphDefBuilderWrapper {
HasAttr(op_def, "output_shapes");
}
- bool HasAttr(const string& op_type_name, const string& attr_name) const {
- const OpDef* op_def = nullptr;
- Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
- if (!s.ok() || op_def == nullptr) {
- return false;
- }
- return HasAttr(op_def, attr_name);
- }
+ bool HasAttr(const string& op_type_name, const string& attr_name) const;
bool HasAttr(const OpDef* op_def, const string& attr_name) const {
for (auto attr : op_def->attr()) {
@@ -548,17 +445,7 @@ class GraphDatasetBase : public DatasetBase {
private:
Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
- string* output_node) const {
- GraphDefBuilder b;
- DatasetGraphDefBuilder db(&b);
- Node* node = nullptr;
- TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
- *output_node = node->name();
- GraphDef graph_def;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- graph_def.SerializeToString(serialized_graph_def);
- return Status::OK();
- }
+ string* output_node) const;
const string op_name_;
};