diff options
Diffstat (limited to 'tensorflow/core/grappler/utils/grappler_test.h')
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.h | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 3bc7bea454..e1394b9c35 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -57,6 +57,15 @@ class GrapplerTest : public ::testing::Test { // Count nodes of the given op-type in a graph. int CountOpNodes(const GraphDef& graph, const string& op); + // Get a random tansor with given shape. + template <DataType DTYPE> + Tensor GenerateRandomTensor(const TensorShape& shape) const { + typedef typename EnumToDataType<DTYPE>::Type T; + Tensor tensor(DTYPE, shape); + tensor.flat<T>() = tensor.flat<T>().random(); + return tensor; + } + private: SessionOptions options_; }; |