diff options
author | Ali Yahya <alive@google.com> | 2017-08-11 15:37:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-11 15:44:17 -0700 |
commit | 49f5fd91a47ce0578b19cb5a36865f3890dddb68 (patch) | |
tree | e9fc4b4cfca1154f1b1dd42cced34756d77ffb54 /tensorflow/python | |
parent | 656ed3d824b42e4f01ed0194d197b06a98850db5 (diff) |
TFE: Fix for bug in graph_only_ops.py.
PiperOrigin-RevId: 165034457
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_only_ops.py | 12 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_only_ops_test.py | 48 |
3 files changed, 72 insertions, 5 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index eafef04c77..e0e65c9dae 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -98,11 +98,11 @@ cuda_py_test( ":context", ":core", ":execute", - "//tensorflow/python:pywrap_tensorflow", ":tensor", ":test", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:framework_test_lib", ], ) @@ -197,6 +197,20 @@ py_library( deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + +cuda_py_test( + name = "graph_only_ops_test", + srcs = ["graph_only_ops_test.py"], + additional_deps = [ + "graph_only_ops", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", ], ) @@ -242,6 +256,7 @@ py_library( ":context", ":core", ":execute", + ":graph_only_ops", ":tensor", ":test", "//tensorflow/python:pywrap_tensorflow", diff --git a/tensorflow/python/eager/graph_only_ops.py b/tensorflow/python/eager/graph_only_ops.py index bd7d08faed..77a9e7db20 100644 --- a/tensorflow/python/eager/graph_only_ops.py +++ b/tensorflow/python/eager/graph_only_ops.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape def graph_zeros_like(tensor): @@ -29,8 +30,8 @@ def graph_zeros_like(tensor): g = ops._get_graph_from_inputs([tensor]) # pylint: disable=protected-access with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name: tensor = ops.convert_to_tensor(tensor, name="tensor") - dtype = tensor.dtype.base_dtype.as_datatype_enum - dtype_value = attr_value_pb2.AttrValue(type=dtype) + dtype = tensor.dtype.base_dtype + dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum) op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype], attrs={"T": dtype_value}, name=name) result, = op.outputs @@ -39,8 +40,11 @@ def graph_zeros_like(tensor): def graph_placeholder(dtype, shape, name=None): """Graph-only version of tf.placeholder(), for internal use only.""" - dtype = dtype.base_dtype.as_datatype_enum - dtype_value = attr_value_pb2.AttrValue(type=dtype) + dtype = dtype.base_dtype + dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum) + if isinstance(shape, (list, tuple)): + shape = tensor_shape.TensorShape(shape) + assert isinstance(shape, tensor_shape.TensorShape) shape = attr_value_pb2.AttrValue(shape=shape.as_proto()) g = ops.get_default_graph() with ops.name_scope(name, "placeholder", []) as name: diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py new file mode 100644 index 0000000000..d2a2b4e223 --- /dev/null +++ b/tensorflow/python/eager/graph_only_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for graph_only_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import graph_only_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class GraphOnlyOpsTest(test_util.TensorFlowTestCase): + + def testGraphZerosLike(self): + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + z_tf = graph_only_ops.graph_zeros_like(x) + with self.test_session(): + self.assertAllClose(np.zeros((2, 3)), z_tf.eval()) + + def testGraphPlaceholder(self): + x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,)) + y_tf = math_ops.square(x_tf) + with self.test_session() as sess: + x = np.array([42]) + y = sess.run(y_tf, feed_dict={x_tf: np.array([42])}) + self.assertAllClose(np.square(x), y) + + +if __name__ == '__main__': + test.main() |