diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-10-03 13:54:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-03 13:58:54 -0700 |
commit | 0a11eaffc985ad6abd3a0e792061e1880766674a (patch) | |
tree | befeeaba41b0716d791fd6d8701fb93b64680c24 /tensorflow/cc/ops | |
parent | 8fb14b1409e44b607dff5faa840e210a90fd586c (diff) |
Internal Variant API allowing registering Variants to be copied from/to GPU.
Adds a test in the variant_op_copy_test.
Modifies the base GPUDevice to use this registry if it sees a singleton variant.
Modifies the rendezvous manager to do the same.
PiperOrigin-RevId: 170908757
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/const_op.cc | 25 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op.h | 2 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op_test.cc | 14 |
3 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc index 0030c2b2a7..a04f37067d 100644 --- a/tensorflow/cc/ops/const_op.cc +++ b/tensorflow/cc/ops/const_op.cc @@ -19,19 +19,17 @@ limitations under the License. namespace tensorflow { namespace ops { -Output Const(const Scope& scope, const Input::Initializer& val) { +namespace { +template <typename T> +Output ConstHelper(const Scope& scope, const T& value, DataType dtype) { if (!scope.ok()) return Output(); - if (!val.status.ok()) { - scope.UpdateStatus(val.status); - return Output(); - } Node* ret; Graph* graph = scope.graph(); const string unique_name = scope.GetUniqueNameForOp("Const"); auto builder = NodeBuilder(unique_name, "Const") - .Attr("value", val.tensor) - .Attr("dtype", val.tensor.dtype()); + .Attr("value", value) + .Attr("dtype", dtype); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(graph, &ret)); if (!scope.ok()) return Output(); @@ -41,6 +39,19 @@ Output Const(const Scope& scope, const Input::Initializer& val) { return Output(ret); } +} // namespace + +Output Const(const Scope& scope, const Input::Initializer& val) { + if (!val.status.ok()) { + scope.UpdateStatus(val.status); + return Output(); + } + return ConstHelper(scope, val.tensor, val.tensor.dtype()); +} + +Output ConstFromProto(const Scope& scope, const TensorProto& proto) { + return ConstHelper(scope, proto, proto.dtype()); +} NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) { if (!inp.status().ok()) { diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 516800920f..d11fda475b 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -28,6 +28,8 @@ namespace ops { Output Const(const Scope& scope, const Input::Initializer& val); +Output ConstFromProto(const Scope& scope, const TensorProto& proto); + NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); template <typename T> diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 3184edeb33..69b5d7fd47 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -100,6 +100,20 @@ TEST(ConstOpTest, WithExplicitShape) { ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); } +TEST(ConstOpTest, FromProto) { + Scope root = Scope::NewRootScope(); + TensorProto proto; + proto.set_dtype(DT_DOUBLE); + TensorShape({2, 2}).AsProto(proto.mutable_tensor_shape()); + for (int i = 0; i < 4; ++i) { + proto.add_double_val(static_cast<double>(i)); + } + auto c = ops::ConstFromProto(root, proto); + TF_CHECK_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_DOUBLE); + ExpectNodeEqual<double>(c.node(), {0.0, 1.0, 2.0, 3.0}, {2, 2}); +} + TEST(ConstOpTest, InvalidInitializer) { Scope root = Scope::NewRootScope(); ops::Const(root, {{2.0}, {"df"}}); |