aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/shape_inference.h')
-rw-r--r--tensorflow/core/framework/shape_inference.h23
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index bb6a66dc53..6385177bc1 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -80,7 +82,10 @@ class InferenceContext {
// the same Dimension*.
//
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
- InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
+ //
+ // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+ InferenceContext(const NodeDef* node_def,
+ const std::vector<string>& input_shapes, int num_outputs,
const std::vector<const Tensor*>& input_tensors = {});
~InferenceContext();
@@ -162,6 +167,12 @@ class InferenceContext {
const Dimension* CreateDim(int64 value);
const Dimension* CreateUnknownDim();
+ // Look up the attr for the NodeDef being evaluated with name attr_name and
+ // set *value to its value. If no attr with attr_name is found in def(), or
+ // the attr does not have a matching type, a non-ok status will be returned.
+ template <class T>
+ Status GetAttr(StringPiece attr_name, T* value) const;
+
private:
Status ReturnUnknownShape(const Shape** out) {
*out = CreateUnknownShape();
@@ -181,9 +192,14 @@ class InferenceContext {
std::vector<const Tensor*> input_tensors_;
std::vector<const Shape*> outputs_;
+ const NodeDef& node_def_;
+
TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {}
@@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<const Dimension*> dims)
: rank_(dims.size()), dims_(dims) {}
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+ return GetNodeAttr(node_def_, attr_name, value);
+}
+
} // namespace shape_inference
} // namespace tensorflow