diff options
author | 2017-05-16 16:08:20 -0700 | |
---|---|---|
committer | 2017-05-16 16:12:05 -0700 | |
commit | 749e5cc18381f7a5ec174673f76e20aead8529c6 (patch) | |
tree | 4b92d36c9e1d8e59e34fd8d08e7f11fbda1315d9 /tensorflow/cc | |
parent | ed5d05d8b53425ef98aad129a60143a5011a4288 (diff) |
Reduce direct references to NodeDef in favor of Node and AttrSlice
This is one step towards replacing in-memory use of NodeDef with a customized
NodeInfo class. There are still quite a few Node::def() references, but far fewer than before. Those remaining require more work, either because they are part of kernel registration (which is a bunch of functions), copy and modify the NodeDef, etc. Follow-on CLs will remove more.
RELNOTES: n/a
PiperOrigin-RevId: 156244933
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/framework/cc_op_gen.cc | 9 | ||||
-rw-r--r-- | tensorflow/cc/framework/cc_ops_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope.cc | 4 | ||||
-rw-r--r-- | tensorflow/cc/gradients/array_grad.cc | 28 | ||||
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 8 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op_test.cc | 8 |
6 files changed, 39 insertions, 35 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 799492a4eb..71aa986f91 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -740,11 +740,10 @@ void OpInfo::GetOutput(string* out) const { return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); - strings::StrAppend( - out, - " ::tensorflow::Status _status_ = " - "::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), " - "nullptr, &_outputs_range);\n"); + strings::StrAppend(out, + " ::tensorflow::Status _status_ = " + "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " + "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 92c97d107d..5da23036ea 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -35,8 +35,8 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) { void GetColocationConstraints(const Output& tensor, std::vector<string>* constraints) { constraints->clear(); - TF_EXPECT_OK( - GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints)); + TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName, + constraints)); } } // namespace @@ -159,11 +159,11 @@ TEST(CCOpTest, KernelLabel) { Scope root = Scope::NewRootScope(); auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f); TF_EXPECT_OK(root.status()); - const auto& attrs = add.z.op().node()->def().attr(); - ASSERT_TRUE(attrs.find("_kernel") != attrs.end()); - auto kernel_attr = attrs.find("_kernel")->second; - TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string")); - EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel"); + AttrSlice attrs = add.z.op().node()->attrs(); + const auto* kernel_attr = attrs.Find("_kernel"); + ASSERT_TRUE(kernel_attr); + TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string")); + EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel"); } TEST(CCOpTest, ColocateWith) { @@ -190,8 +190,7 @@ TEST(CCOpTest, ColocateWith) { Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4); auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7); - const auto& attrs = c6.op().node()->def().attr(); - EXPECT_TRUE(attrs.find("_class") == attrs.end()); + EXPECT_FALSE(c6.op().node()->attrs().Find("_class")); } TEST(CCOpTest, TemplatedConst) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8b7fc1406f..32c0822de6 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -271,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, std::unordered_set<string> Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set<string> current_constraints(colocation_constraints_); - const NodeDef& node_def = colocate_with_op.node()->def(); + const AttrSlice attrs = colocate_with_op.node()->attrs(); std::vector<string> node_constraints; - if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) { + if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); if (s.Consume(kColocationGroupPrefix)) { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 26abd2438e..37f07e71a0 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int N; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->reserve(N); auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); @@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); return scope.status(); } @@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { string message; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); string err_msg = strings::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. ", message); @@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op, std::vector<Output>* grad_outputs) { auto seq_lengths = op.input(1); int batch_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); int seq_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); grad_outputs->push_back( ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, ReverseSequence::BatchDim(batch_dim))); @@ -267,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -290,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -313,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -323,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -333,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); @@ -346,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 5a2c6d11fb..8c1a01f518 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -350,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, const string& attr_adj_x, const string& attr_adj_y, std::vector<Output>* grad_outputs) { DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype)); if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { return errors::Unimplemented( "MatMul gradient for complex data type is not supported yet."); @@ -358,8 +358,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, bool ta; bool tb; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta)); - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 5a4770f879..3184edeb33 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values, TensorShape shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(tensor.dtype(), dtype); test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape)); } @@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype, TensorShape expected_shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(dtype, expected_dtype); EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); } |