aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-08-12 09:19:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-12 10:32:56 -0700
commit20a869a79b04eebd930b9085e32927b46fabd149 (patch)
treeee94181ae26a2c29e290469b61b4b6408a5488d8 /tensorflow/cc/ops
parent7baab1097c121947e01765fd1b063c62dfb34143 (diff)
C++ API: Added a Const constructor for non-empty const supporting type cast.
Fixes #3752 Change: 130113000
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/const_op.h37
-rw-r--r--tensorflow/cc/ops/const_op_test.cc9
2 files changed, 33 insertions, 13 deletions
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
index 75844d124d..8976a24edc 100644
--- a/tensorflow/cc/ops/const_op.h
+++ b/tensorflow/cc/ops/const_op.h
@@ -25,22 +25,35 @@ namespace ops {
Output Const(const Scope& scope, const Input::Initializer& val);
+NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
+
template <typename T>
Output Const(const Scope& scope, const Input::Initializer& val) {
+ auto orig_const_output = Const(scope, val);
if (!scope.ok()) return Output();
- if (!val.status.ok()) {
- scope.UpdateStatus(val.status);
- return Output();
- }
+
typedef typename Input::Initializer::RealType<T>::type DstT;
- if (val.tensor.NumElements() > 0) {
- // TODO(keveman): Implement the in-situ cast.
- scope.UpdateStatus(errors::Unimplemented(
- "Explict cast of a non-empty tensor not implemented yet"));
- return Output();
+
+ if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) {
+ return orig_const_output;
}
- Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
- return Const(scope, Input::Initializer(t));
+ if (val.tensor.NumElements() == 0) {
+ Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
+ return Const(scope, Input::Initializer(t));
+ }
+
+ // TODO(keveman): Refactor Cast op's kernel implementation such that the code
+ // can be directly called here instead of adding the Cast op to the graph.
+ auto orig_const = AsNodeOut(scope, orig_const_output);
+ const auto cast_op_name = scope.GetUniqueNameForOp("Cast");
+
+ auto cast_builder = NodeBuilder(cast_op_name, "Cast")
+ .Input(orig_const)
+ .Attr("DstT", DataTypeToEnum<DstT>::v());
+ scope.UpdateBuilder(&cast_builder);
+ Node* ret;
+ scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
+ return Output(ret, 0);
}
template <typename T>
@@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v,
return Const(scope, Input::Initializer(v, shape));
}
-NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
-
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
const InputList& inp);
diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc
index a56b66c1cc..5a4770f879 100644
--- a/tensorflow/cc/ops/const_op_test.cc
+++ b/tensorflow/cc/ops/const_op_test.cc
@@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) {
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
}
+TEST(ConstOpTest, TemplatedConst) {
+ Scope root = Scope::NewRootScope();
+ auto c1 = ops::Const<int>(root, {1, 2});
+ ExpectTypeAndShape(c1.node(), DT_INT32, {2});
+
+ auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
+ ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1});
+}
+
} // namespace tensorflow