diff options
Diffstat (limited to 'tensorflow/core/framework/dataset.cc')
-rw-r--r-- | tensorflow/core/framework/dataset.cc | 54 |
1 files changed, 20 insertions, 34 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index e886ef7b8e..f3c7189292 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -74,18 +74,18 @@ class DatasetVariantWrapper { } // namespace Status GraphDefBuilderWrapper::AddDataset( - const GraphDatasetBase* dataset, + const DatasetBase* dataset, const std::vector<std::pair<size_t, Node*>>& inputs, const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, Node** output) { - const string& op_type_name = dataset->op_name(); + const string& name = dataset->name(); std::unique_ptr<const GraphDefBuilder::Options> opts( new GraphDefBuilder::Options(b_->opts())); // TODO(srbs|mrry): Not all datasets have output_types and output_shapes // attributes defined. It will be nice to have a consistent pattern. - bool has_output_types_attr = HasAttr(op_type_name, "output_types"); - bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); + bool has_output_types_attr = HasAttr(name, "output_types"); + bool has_output_shapes_attr = HasAttr(name, "output_shapes"); if (has_output_shapes_attr) { opts.reset(new GraphDefBuilder::Options( opts->WithAttr("output_shapes", dataset->output_shapes()))); @@ -102,8 +102,7 @@ Status GraphDefBuilderWrapper::AddDataset( return errors::Internal("AddDataset: Failed to build Options with error ", opts->StatusToString()); } - NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, - opts->op_registry()); + NodeBuilder node_builder(opts->GetNameForOp(name), name, opts->op_registry()); { size_t total_size = inputs.size() + list_inputs.size(); auto inputs_iter = inputs.begin(); @@ -128,7 +127,7 @@ Status GraphDefBuilderWrapper::AddDataset( } *output = opts->FinalizeBuilder(&node_builder); if (*output == nullptr) { - return errors::Internal("AddDataset: Failed to build ", op_type_name, + return errors::Internal("AddDataset: Failed to build ", name, " op with error ", opts->StatusToString()); } return Status::OK(); @@ -184,27 +183,32 @@ void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); } -bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name, +bool GraphDefBuilderWrapper::HasAttr(const string& name, const string& attr_name) const { const OpDef* op_def = nullptr; - Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); + Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def); if (!s.ok() || op_def == nullptr) { return false; } return HasAttr(op_def, attr_name); } -Status GraphDatasetBase::Serialize(SerializationContext* ctx, - string* serialized_graph_def, - string* output_node) const { +Status DatasetBase::Save(SerializationContext* ctx, + IteratorStateWriter* writer) const { + string serialized_graph_def; + string output_node; GraphDefBuilder b; DatasetGraphDefBuilder db(&b); Node* node = nullptr; TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); - *output_node = node->name(); + output_node = node->name(); GraphDef graph_def; TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - graph_def.SerializeToString(serialized_graph_def); + graph_def.SerializeToString(&serialized_graph_def); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); return Status::OK(); } @@ -264,8 +268,8 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, MakeDataset(ctx, input, another_input, output); } -const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; -const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = +const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; +const char DatasetBase::kDatasetGraphOutputNodeKey[] = "_DATASET_GRAPH_OUTPUT_NODE"; BackgroundWorker::BackgroundWorker(Env* env, const string& name) { @@ -315,22 +319,4 @@ void BackgroundWorker::WorkerLoop() { } } -namespace dataset { - -IteratorContext MakeIteratorContext(OpKernelContext* ctx) { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.lib = ctx->function_library(); - // Note: must use reinterpret_cast because function.h forward-declares Device. - DeviceBase* device = - reinterpret_cast<DeviceBase*>(ctx->function_library()->device()); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - return IteratorContext(params); -} - -} // namespace dataset - } // namespace tensorflow |