diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 4300784ffe..c6da445165 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -49,6 +49,25 @@ InferenceContext::InferenceContext( InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<string>& input_shapes_string, + const std::vector<TensorShapeProto>& input_shapes, + const std::vector<const Tensor*>& input_tensors) + : node_def_(*CHECK_NOTNULL(node_def)) { + PreInputInit(op_def, input_tensors); + if (!construction_status_.ok()) return; + for (const TensorShapeProto& p : input_shapes) { + const Shape* shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + PostInputInit(); +} + +InferenceContext::InferenceContext( + const NodeDef* node_def, const OpDef& op_def, + const std::vector<string>& input_shapes_string, const std::vector<const Shape*>& input_shapes, const std::vector<const Tensor*>& input_tensors) : node_def_(*CHECK_NOTNULL(node_def)) { |