diff options
author | 2016-08-12 09:19:52 -0800 | |
---|---|---|
committer | 2016-08-12 10:32:56 -0700 | |
commit | 20a869a79b04eebd930b9085e32927b46fabd149 (patch) | |
tree | ee94181ae26a2c29e290469b61b4b6408a5488d8 /tensorflow/cc/ops | |
parent | 7baab1097c121947e01765fd1b063c62dfb34143 (diff) |
C++ API: Added a Const constructor for non-empty const supporting type cast.
Fixes #3752
Change: 130113000
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/const_op.h | 37 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op_test.cc | 9 |
2 files changed, 33 insertions, 13 deletions
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 75844d124d..8976a24edc 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -25,22 +25,35 @@ namespace ops { Output Const(const Scope& scope, const Input::Initializer& val); +NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); + template <typename T> Output Const(const Scope& scope, const Input::Initializer& val) { + auto orig_const_output = Const(scope, val); if (!scope.ok()) return Output(); - if (!val.status.ok()) { - scope.UpdateStatus(val.status); - return Output(); - } + typedef typename Input::Initializer::RealType<T>::type DstT; - if (val.tensor.NumElements() > 0) { - // TODO(keveman): Implement the in-situ cast. - scope.UpdateStatus(errors::Unimplemented( - "Explict cast of a non-empty tensor not implemented yet")); - return Output(); + + if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) { + return orig_const_output; } - Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape()); - return Const(scope, Input::Initializer(t)); + if (val.tensor.NumElements() == 0) { + Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape()); + return Const(scope, Input::Initializer(t)); + } + + // TODO(keveman): Refactor Cast op's kernel implementation such that the code + // can be directly called here instead of adding the Cast op to the graph. + auto orig_const = AsNodeOut(scope, orig_const_output); + const auto cast_op_name = scope.GetUniqueNameForOp("Cast"); + + auto cast_builder = NodeBuilder(cast_op_name, "Cast") + .Input(orig_const) + .Attr("DstT", DataTypeToEnum<DstT>::v()); + scope.UpdateBuilder(&cast_builder); + Node* ret; + scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); + return Output(ret, 0); } template <typename T> @@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v, return Const(scope, Input::Initializer(v, shape)); } -NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); - std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope, const InputList& inp); diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index a56b66c1cc..5a4770f879 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) { EXPECT_EQ(c_y_1.node()->name(), "c/y_1"); } +TEST(ConstOpTest, TemplatedConst) { + Scope root = Scope::NewRootScope(); + auto c1 = ops::Const<int>(root, {1, 2}); + ExpectTypeAndShape(c1.node(), DT_INT32, {2}); + + auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); + ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1}); +} + } // namespace tensorflow |