diff options
author | 2016-08-23 13:33:04 -0800 | |
---|---|---|
committer | 2016-08-23 14:46:32 -0700 | |
commit | 011402d8987e753acd54c6251a7edd2e2d8155ba (patch) | |
tree | d718a73c7e5d2f340f1ab07079d2bce711b30237 /tensorflow/core/framework/shape_inference.cc | |
parent | d0550db5736f484c12ac7f52dfaf2aa581d3170f (diff) |
Allow a python shape inference fn to delegate to the cpp shape
inference function. Enable this for MatMul and SparseMatMul.
Change: 131097313
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 4300784ffe..c6da445165 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -49,6 +49,25 @@ InferenceContext::InferenceContext( InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<string>& input_shapes_string, + const std::vector<TensorShapeProto>& input_shapes, + const std::vector<const Tensor*>& input_tensors) + : node_def_(*CHECK_NOTNULL(node_def)) { + PreInputInit(op_def, input_tensors); + if (!construction_status_.ok()) return; + for (const TensorShapeProto& p : input_shapes) { + const Shape* shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + PostInputInit(); +} + +InferenceContext::InferenceContext( + const NodeDef* node_def, const OpDef& op_def, + const std::vector<string>& input_shapes_string, const std::vector<const Shape*>& input_shapes, const std::vector<const Tensor*>& input_tensors) : node_def_(*CHECK_NOTNULL(node_def)) { |