aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-10-03 13:54:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-03 13:58:54 -0700
commit0a11eaffc985ad6abd3a0e792061e1880766674a (patch)
treebefeeaba41b0716d791fd6d8701fb93b64680c24 /tensorflow/cc/ops
parent8fb14b1409e44b607dff5faa840e210a90fd586c (diff)
Internal Variant API allowing registering Variants to be copied from/to GPU.
Adds a test in the variant_op_copy_test. Modifies the base GPUDevice to use this registry if it sees a singleton variant. Modifies the rendezvous manager to do the same. PiperOrigin-RevId: 170908757
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/const_op.cc25
-rw-r--r--tensorflow/cc/ops/const_op.h2
-rw-r--r--tensorflow/cc/ops/const_op_test.cc14
3 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc
index 0030c2b2a7..a04f37067d 100644
--- a/tensorflow/cc/ops/const_op.cc
+++ b/tensorflow/cc/ops/const_op.cc
@@ -19,19 +19,17 @@ limitations under the License.
namespace tensorflow {
namespace ops {
-Output Const(const Scope& scope, const Input::Initializer& val) {
+namespace {
+template <typename T>
+Output ConstHelper(const Scope& scope, const T& value, DataType dtype) {
if (!scope.ok()) return Output();
- if (!val.status.ok()) {
- scope.UpdateStatus(val.status);
- return Output();
- }
Node* ret;
Graph* graph = scope.graph();
const string unique_name = scope.GetUniqueNameForOp("Const");
auto builder = NodeBuilder(unique_name, "Const")
- .Attr("value", val.tensor)
- .Attr("dtype", val.tensor.dtype());
+ .Attr("value", value)
+ .Attr("dtype", dtype);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(graph, &ret));
if (!scope.ok()) return Output();
@@ -41,6 +39,19 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
return Output(ret);
}
+} // namespace
+
+Output Const(const Scope& scope, const Input::Initializer& val) {
+ if (!val.status.ok()) {
+ scope.UpdateStatus(val.status);
+ return Output();
+ }
+ return ConstHelper(scope, val.tensor, val.tensor.dtype());
+}
+
+Output ConstFromProto(const Scope& scope, const TensorProto& proto) {
+ return ConstHelper(scope, proto, proto.dtype());
+}
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) {
if (!inp.status().ok()) {
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
index 516800920f..d11fda475b 100644
--- a/tensorflow/cc/ops/const_op.h
+++ b/tensorflow/cc/ops/const_op.h
@@ -28,6 +28,8 @@ namespace ops {
Output Const(const Scope& scope, const Input::Initializer& val);
+Output ConstFromProto(const Scope& scope, const TensorProto& proto);
+
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
template <typename T>
diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc
index 3184edeb33..69b5d7fd47 100644
--- a/tensorflow/cc/ops/const_op_test.cc
+++ b/tensorflow/cc/ops/const_op_test.cc
@@ -100,6 +100,20 @@ TEST(ConstOpTest, WithExplicitShape) {
ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
}
+TEST(ConstOpTest, FromProto) {
+ Scope root = Scope::NewRootScope();
+ TensorProto proto;
+ proto.set_dtype(DT_DOUBLE);
+ TensorShape({2, 2}).AsProto(proto.mutable_tensor_shape());
+ for (int i = 0; i < 4; ++i) {
+ proto.add_double_val(static_cast<double>(i));
+ }
+ auto c = ops::ConstFromProto(root, proto);
+ TF_CHECK_OK(root.status());
+ EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
+ ExpectNodeEqual<double>(c.node(), {0.0, 1.0, 2.0, 3.0}, {2, 2});
+}
+
TEST(ConstOpTest, InvalidInitializer) {
Scope root = Scope::NewRootScope();
ops::Const(root, {{2.0}, {"df"}});