aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4300784ffe..c6da445165 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -49,6 +49,25 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
+ const std::vector<TensorShapeProto>& input_shapes,
+ const std::vector<const Tensor*>& input_tensors)
+ : node_def_(*CHECK_NOTNULL(node_def)) {
+ PreInputInit(op_def, input_tensors);
+ if (!construction_status_.ok()) return;
+ for (const TensorShapeProto& p : input_shapes) {
+ const Shape* shape;
+ construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
+ if (!construction_status_.ok()) {
+ return;
+ }
+ inputs_.push_back(shape);
+ }
+ PostInputInit();
+}
+
+InferenceContext::InferenceContext(
+ const NodeDef* node_def, const OpDef& op_def,
+ const std::vector<string>& input_shapes_string,
const std::vector<const Shape*>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
: node_def_(*CHECK_NOTNULL(node_def)) {