aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Ali Yahya <alive@google.com>2017-08-11 15:37:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-11 15:44:17 -0700
commit49f5fd91a47ce0578b19cb5a36865f3890dddb68 (patch)
treee9fc4b4cfca1154f1b1dd42cced34756d77ffb54 /tensorflow/python
parent656ed3d824b42e4f01ed0194d197b06a98850db5 (diff)
TFE: Fix for bug in graph_only_ops.py.
PiperOrigin-RevId: 165034457
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/BUILD17
-rw-r--r--tensorflow/python/eager/graph_only_ops.py12
-rw-r--r--tensorflow/python/eager/graph_only_ops_test.py48
3 files changed, 72 insertions, 5 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index eafef04c77..e0e65c9dae 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -98,11 +98,11 @@ cuda_py_test(
":context",
":core",
":execute",
- "//tensorflow/python:pywrap_tensorflow",
":tensor",
":test",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:framework_test_lib",
],
)
@@ -197,6 +197,20 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+cuda_py_test(
+ name = "graph_only_ops_test",
+ srcs = ["graph_only_ops_test.py"],
+ additional_deps = [
+ "graph_only_ops",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
],
)
@@ -242,6 +256,7 @@ py_library(
":context",
":core",
":execute",
+ ":graph_only_ops",
":tensor",
":test",
"//tensorflow/python:pywrap_tensorflow",
diff --git a/tensorflow/python/eager/graph_only_ops.py b/tensorflow/python/eager/graph_only_ops.py
index bd7d08faed..77a9e7db20 100644
--- a/tensorflow/python/eager/graph_only_ops.py
+++ b/tensorflow/python/eager/graph_only_ops.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
def graph_zeros_like(tensor):
@@ -29,8 +30,8 @@ def graph_zeros_like(tensor):
g = ops._get_graph_from_inputs([tensor]) # pylint: disable=protected-access
with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
- dtype = tensor.dtype.base_dtype.as_datatype_enum
- dtype_value = attr_value_pb2.AttrValue(type=dtype)
+ dtype = tensor.dtype.base_dtype
+ dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
attrs={"T": dtype_value}, name=name)
result, = op.outputs
@@ -39,8 +40,11 @@ def graph_zeros_like(tensor):
def graph_placeholder(dtype, shape, name=None):
"""Graph-only version of tf.placeholder(), for internal use only."""
- dtype = dtype.base_dtype.as_datatype_enum
- dtype_value = attr_value_pb2.AttrValue(type=dtype)
+ dtype = dtype.base_dtype
+ dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
+ if isinstance(shape, (list, tuple)):
+ shape = tensor_shape.TensorShape(shape)
+ assert isinstance(shape, tensor_shape.TensorShape)
shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
g = ops.get_default_graph()
with ops.name_scope(name, "placeholder", []) as name:
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
new file mode 100644
index 0000000000..d2a2b4e223
--- /dev/null
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for graph_only_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.eager import graph_only_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
+
+ def testGraphZerosLike(self):
+ x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ z_tf = graph_only_ops.graph_zeros_like(x)
+ with self.test_session():
+ self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
+
+ def testGraphPlaceholder(self):
+ x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
+ y_tf = math_ops.square(x_tf)
+ with self.test_session() as sess:
+ x = np.array([42])
+ y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
+ self.assertAllClose(np.square(x), y)
+
+
+if __name__ == '__main__':
+ test.main()