diff options
Diffstat (limited to 'tensorflow/java/src')
13 files changed, 9 insertions, 720 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java deleted file mode 100644 index dff3a45463..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/Input.java +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow; - -/** - * Interface implemented by operands of a TensorFlow operation. - * - * <p>Example usage: - * - * <pre>{@code - * // The "decodeJpeg" operation can be used as input to the "cast" operation - * Input decodeJpeg = ops.image().decodeJpeg(...); - * ops.math().cast(decodeJpeg, DataType.FLOAT); - * - * // The output "y" of the "unique" operation can be used as input to the "cast" operation - * Output y = ops.array().unique(...).y(); - * ops.math().cast(y, DataType.FLOAT); - * - * // The "split" operation can be used as input list to the "concat" operation - * Iterable<? extends Input> split = ops.array().split(...); - * ops.array().concat(0, split); - * }</pre> - */ -public interface Input { - - /** - * Returns the symbolic handle of a tensor. - * - * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is - * used to obtain a symbolic handle that represents the computation of the input. - * - * @see {@link OperationBuilder#addInput(Output)}. - */ - Output asOutput(); -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java index e7de603409..43dbaf125c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -79,7 +79,7 @@ public final class Operation { * * @param name identifier of the list of tensors (of which there may * be many) produced by this operation. - * @return the size of the list of Tensors produced by this named output. + * @returns the size of the list of Tensors produced by this named output. * @throws IllegalArgumentException if this operation has no output * with the provided name. */ @@ -97,28 +97,6 @@ public final class Operation { return new Output(this, idx); } - /** - * Returns the size of the given inputs list of Tensors for this operation. - * - * <p>An Operation has multiple named inputs, each of which contains either - * a single tensor or a list of tensors. This method returns the size of - * the list of tensors for a specific named input of the operation. - * - * @param name identifier of the list of tensors (of which there may - * be many) inputs to this operation. - * @returns the size of the list of Tensors produced by this named input. - * @throws IllegalArgumentException if this operation has no input - * with the provided name. - */ - public int inputListLength(final String name) { - Graph.Reference r = graph.ref(); - try { - return inputListLength(unsafeNativeHandle, name); - } finally { - r.close(); - } - } - long getUnsafeNativeHandle() { return unsafeNativeHandle; } @@ -154,8 +132,6 @@ public final class Operation { private static native int outputListLength(long handle, String name); - private static native int inputListLength(long handle, String name); - private static native long[] shape(long graphHandle, long opHandle, int output); private static native int dtype(long graphHandle, long opHandle, int output); diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java index 8f7559d39e..38ffa2a8e1 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java @@ -28,7 +28,7 @@ import java.nio.charset.Charset; * <pre>{@code * // g is a Graph instance. * try (Tensor c1 = Tensor.create(3.0f)) { - * g.opBuilder("Const", "MyConst") + * g.opBuilder("Constant", "MyConst") * .setAttr("dtype", c1.dataType()) * .setAttr("value", c1) * .build(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java index 2e3f8d4eac..ab128c2b30 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Output.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java @@ -20,11 +20,8 @@ package org.tensorflow; * * <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing * the {@link Operation} in a {@link Session}. - * - * <p>By implementing the {@link Input} interface, instances of this class could also be passed - * directly in input to an operation. */ -public final class Output implements Input { +public final class Output { /** Handle to the idx-th output of the Operation {@code op}. */ public Output(Operation op, int idx) { @@ -52,11 +49,6 @@ public final class Output implements Input { return operation.dtype(index); } - @Override - public Output asOutput() { - return this; - } - private final Operation operation; private final int index; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index f73cded4e3..0d071e1674 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -125,7 +125,7 @@ public final class Session implements AutoCloseable { * <tt>operation_name:output_index</tt> , in which case this method acts like {@code * feed(operation_name, output_index)}. These colon-separated names are commonly used in the * {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * SavedModelBundle.metaGraphDef()}. */ public Runner feed(String operation, Tensor t) { return feed(parseOutput(operation), t); @@ -165,7 +165,7 @@ public final class Session implements AutoCloseable { * <tt>operation_name:output_index</tt> , in which case this method acts like {@code * fetch(operation_name, output_index)}. These colon-separated names are commonly used in * the {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * SavedModelBundle.metaGraphDef()}. */ public Runner fetch(String operation) { return fetch(parseOutput(operation)); diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java deleted file mode 100644 index 2e84cac1ac..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.op; - -import java.util.HashMap; -import java.util.Map; -import java.util.regex.Pattern; - -/** - * A class to manage scoped (hierarchical) names for operators. - * - * <p>{@code NameScope} manages hierarchical names where each component in the hierarchy is - * separated by a forward slash {@code '/'}. For instance, {@code nn/Const_72} or {@code - * nn/gradient/assign/init}. Each scope is a subtree in this hierarchy. - * - * <p>Use {@code NameScope} to group related operations within a hierarchy, which for example lets - * tensorboard coalesce nodes for better graph visualizations. - * - * <p>This class is package private, user code creates {@link Scope} which internally delegates - * calls to an underlying {@code NameScope}. - * - * <p>This class is <b>not</b> thread-safe. - */ -final class NameScope { - - NameScope withSubScope(String scopeName) { - checkPattern(NAME_REGEX, scopeName); - // Override with opName if it exists. - String actualName = (opName != null) ? opName : scopeName; - String newPrefix = fullyQualify(makeUnique(actualName)); - return new NameScope(newPrefix, null, null); - } - - NameScope withName(String name) { - checkPattern(NAME_REGEX, name); - // All context except for the opName is shared with the new scope. - return new NameScope(opPrefix, name, ids); - } - - String makeOpName(String name) { - checkPattern(NAME_REGEX, name); - // Override with opName if it exists. - String actualName = (opName != null) ? opName : name; - return fullyQualify(makeUnique(actualName)); - } - - /** - * Create a new, root-level namescope. - * - * <p>A root-level namescope generates operator names with no components, like {@code Const_72} - * and {@code result}. - */ - NameScope() { - this(null, null, null); - } - - private NameScope(String opPrefix, String opName, Map<String, Integer> ids) { - this.opPrefix = opPrefix; - this.opName = opName; - if (ids != null) { - this.ids = ids; - } else { - this.ids = new HashMap<String, Integer>(); - } - } - - // Generate a unique name, different from existing ids. - // - // ids is a map from id to integer, representing a counter of the - // number of previous requests to generate a unique name for the - // given id. - // - // For instance, the first use of makeUnique("a") adds "a" -> 1 - // to ids and returns "a". - // - // The second use of makeUnique("a") updates ids to "a" -> 2 - // and returns "a_1", and so on. - private String makeUnique(String id) { - if (!ids.containsKey(id)) { - ids.put(id, 1); - return id; - } else { - int cur = ids.get(id); - ids.put(id, cur + 1); - return String.format("%s_%d", id, cur); - } - } - - private String fullyQualify(String name) { - if (opPrefix != null) { - return String.format("%s/%s", opPrefix, name); - } else { - return name; - } - } - - // If opPrefix is non-null, it is a prefix applied to all names - // created by this instance. - private final String opPrefix; - - // If opName is non-null, it is used to derive the unique name - // for operators rather than the provided default name. - private final String opName; - - // NameScope generates unique names by appending a numeric suffix if - // needed. This is a map containing names already created by this - // instance mapped to the next available numeric suffix for it. - private final Map<String, Integer> ids; - - private static void checkPattern(Pattern pattern, String name) { - if (name == null) { - throw new IllegalArgumentException("Names cannot be null"); - } - if (!pattern.matcher(name).matches()) { - throw new IllegalArgumentException( - String.format( - "invalid name: '%s' does not match the regular expression %s", - name, NAME_REGEX.pattern())); - } - } - - // The constraints for operator and scope names originate from restrictions on node names - // noted in the proto definition core/framework/node_def.proto for NodeDef and actually - // implemented in core/framework/node_def_util.cc [Note that the proto comment does not include - // dash (-) in names, while the actual implementation permits it. This regex follows the actual - // implementation.] - // - // This pattern is used to ensure fully qualified names always start with a LETTER_DIGIT_DOT, - // followed by zero or more LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE. SLASH is not permitted in - // actual user-supplied names to NameScope - it is used as a reserved character to separate - // subcomponents within fully qualified names. - private static final Pattern NAME_REGEX = Pattern.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-]*"); -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java deleted file mode 100644 index 8de2eaeb79..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.op; - -import org.tensorflow.Graph; - -/** - * Manages groups of related properties when creating Tensorflow Operations, such as a common name - * prefix. - * - * <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user - * code initializes a {@code Scope} and provides it to Operation building classes. For example: - * - * <pre>{@code - * Scope scope = new Scope(graph); - * Constant c = Constant.create(scope, 42); - * }</pre> - * - * <p>An Operation building class acquires a Scope, and uses it to set properties on the underlying - * Tensorflow ops. For example: - * - * <pre>{@code - * // An operator class that adds a constant. - * public class Constant { - * public static Constant create(Scope scope, ...) { - * scope.graph().opBuilder( - * "Const", scope.makeOpName("Const")) - * .setAttr(...) - * .build() - * ... - * } - * } - * }</pre> - * - * <p><b>Scope hierarchy:</b> - * - * <p>A {@code Scope} provides various {@code with()} methods that create a new scope. The new scope - * typically has one property changed while other properties are inherited from the parent scope. - * - * <p>An example using {@code Constant} implemented as before: - * - * <pre>{@code - * Scope root = new Scope(graph); - * - * // The linear subscope will generate names like linear/... - * Scope linear = Scope.withSubScope("linear"); - * - * // This op name will be "linear/W" - * Constant.create(linear.withName("W"), ...); - * - * // This op will be "linear/Const", using the default - * // name provided by Constant - * Constant.create(linear, ...); - * - * // This op will be "linear/Const_1", using the default - * // name provided by Constant and making it unique within - * // this scope - * Constant.create(linear, ...); - * }</pre> - * - * <p>Scope objects are <b>not</b> thread-safe. - */ -public final class Scope { - - /** - * Create a new top-level scope. - * - * @param graph The graph instance to be managed by the scope. - */ - public Scope(Graph graph) { - this(graph, new NameScope()); - } - - /** Returns the graph managed by this scope. */ - public Graph graph() { - return graph; - } - - /** - * Returns a new scope where added operations will have the provided name prefix. - * - * <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current - * scope. - * - * <p>The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} - * - * @param childScopeName name for the new child scope - * @return a new subscope - * @throws IllegalArgumentException if the name is invalid - */ - public Scope withSubScope(String childScopeName) { - return new Scope(graph, nameScope.withSubScope(childScopeName)); - } - - /** - * Return a new scope that uses the provided name for an op. - * - * <p>Operations created within this scope will have a name of the form {@code - * name/opName[_suffix]}. This lets you name a specific operator more meaningfully. - * - * <p>Names must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} - * - * @param opName name for an operator in the returned scope - * @return a new Scope that uses opName for operations. - * @throws IllegalArgumentException if the name is invalid - */ - public Scope withName(String opName) { - return new Scope(graph, nameScope.withName(opName)); - } - - /** - * Create a unique name for an operator, using a provided default if necessary. - * - * <p>This is normally called only by operator building classes. - * - * <p>This method generates a unique name, appropriate for the name scope controlled by this - * instance. Typical operator building code might look like - * - * <pre>{@code - * scope.graph().opBuilder("Const", scope.makeOpName("Const"))... - * }</pre> - * - * <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a - * set of related operations to the graph by calling other operator building code) you should also - * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a - * meaningful name. - * - * <pre>{@code - * public static Stddev create(Scope scope, ...) { - * // group sub-operations under a common name - * Scope group = scope.withSubScope("stddev"); - * ... Sqrt.create(group, Mean.create(group, ...)) - * } - * }</pre> - * - * @param defaultName name for the underlying operator. - * @return unique name for the operator. - * @throws IllegalArgumentException if the default name is invalid. - */ - public String makeOpName(String defaultName) { - return nameScope.makeOpName(defaultName); - } - - private Scope(Graph graph, NameScope nameScope) { - this.graph = graph; - this.nameScope = nameScope; - } - - private final Graph graph; - private final NameScope nameScope; -} diff --git a/tensorflow/java/src/main/native/operation_jni.cc b/tensorflow/java/src/main/native/operation_jni.cc index ccc44d91c0..b3d5fc4ec3 100644 --- a/tensorflow/java/src/main/native/operation_jni.cc +++ b/tensorflow/java/src/main/native/operation_jni.cc @@ -156,21 +156,3 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env, return static_cast<jint>(TF_OperationOutputType(TF_Output{op, output_index})); } - -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env, - jclass clazz, - jlong handle, - jstring name) { - TF_Operation* op = requireHandle(env, handle); - if (op == nullptr) return 0; - - TF_Status* status = TF_NewStatus(); - - const char* cname = env->GetStringUTFChars(name, nullptr); - int result = TF_OperationInputListLength(op, cname, status); - env->ReleaseStringUTFChars(name, cname); - - throwExceptionIfNotOK(env, status); - TF_DeleteStatus(status); - return result; -} diff --git a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/operation_jni.h index 6f379256d2..b5d156f7c2 100644 --- a/tensorflow/java/src/main/native/operation_jni.h +++ b/tensorflow/java/src/main/native/operation_jni.h @@ -73,17 +73,6 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(JNIEnv *, JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv *, jclass, jlong, jlong, jint); - -/* - * Class: org_tensorflow_Operation - * Method: inputListLength - * Signature: (JLjava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv *, - jclass, - jlong, - jstring); - #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java index 4fa68130c0..74fdcf484e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java @@ -52,16 +52,6 @@ public class OperationTest { assertEquals(3, split(new int[] {0, 1, 2}, 3)); } - @Test - public void inputListLength() { - assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim")); - try { - splitWithInputList(new int[] {0, 1}, 2, "inputs"); - } catch (IllegalArgumentException iae) { - // expected - } - } - private static int split(int[] values, int num_split) { try (Graph g = new Graph()) { return g.opBuilder("Split", "Split") @@ -72,15 +62,4 @@ public class OperationTest { .outputListLength("output"); } } - - private static int splitWithInputList(int[] values, int num_split, String name) { - try (Graph g = new Graph()) { - return g.opBuilder("Split", "Split") - .addInput(TestUtil.constant(g, "split_dim", 0)) - .addInput(TestUtil.constant(g, "values", values)) - .setAttr("num_split", num_split) - .build() - .inputListLength(name); - } - } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index 50bdf351e3..0d2dbc5b88 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -109,7 +109,7 @@ public class SessionTest { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); - // Sanity check on metadata + // Sanity check on metadatar // See comments in fullTraceRunOptions() for an explanation about // why this check is really silly. Ideally, this would be: /* @@ -187,7 +187,7 @@ public class SessionTest { // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 // - // For this test, for now, the use of specific bytes suffices. + // For this test, for now, the use of specific bytes sufficies. return new byte[] {0x08, 0x03}; /* return org.tensorflow.framework.RunOptions.newBuilder() @@ -207,7 +207,7 @@ public class SessionTest { // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 // - // For this test, for now, the use of specific bytes suffices. + // For this test, for now, the use of specific bytes sufficies. return new byte[] {0x10, 0x01, 0x28, 0x01}; /* return org.tensorflow.framework.ConfigProto.newBuilder() diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 3ff59e71b2..44eecc1d1e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -472,7 +472,7 @@ public class TensorTest { @Test public void fromHandle() { // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been - // created independently of the Java code. In practice, two Tensor instances MUST NOT have the + // created indepdently of the Java code. In practice, two Tensor instances MUST NOT have the // same native handle. // // An exception is made for this test, where the pitfalls of this is avoided by not calling diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java deleted file mode 100644 index 9256cb281d..0000000000 --- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.op; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.fail; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.tensorflow.Graph; -import org.tensorflow.Output; -import org.tensorflow.Session; -import org.tensorflow.Tensor; - -/** Unit tests for {@link org.tensorflow.Scope}. */ -@RunWith(JUnit4.class) -public class ScopeTest { - - @Test - public void basicNames() { - try (Graph g = new Graph()) { - Scope root = new Scope(g); - assertEquals("add", root.makeOpName("add")); - assertEquals("add_1", root.makeOpName("add")); - assertEquals("add_2", root.makeOpName("add")); - assertEquals("mul", root.makeOpName("mul")); - } - } - - @Test - public void hierarchicalNames() { - try (Graph g = new Graph()) { - Scope root = new Scope(g); - Scope child = root.withSubScope("child"); - assertEquals("child/add", child.makeOpName("add")); - assertEquals("child/add_1", child.makeOpName("add")); - assertEquals("child/mul", child.makeOpName("mul")); - - Scope child_1 = root.withSubScope("child"); - assertEquals("child_1/add", child_1.makeOpName("add")); - assertEquals("child_1/add_1", child_1.makeOpName("add")); - assertEquals("child_1/mul", child_1.makeOpName("mul")); - - Scope c_c = root.withSubScope("c").withSubScope("c"); - assertEquals("c/c/add", c_c.makeOpName("add")); - - Scope c_1 = root.withSubScope("c"); - Scope c_1_c = c_1.withSubScope("c"); - assertEquals("c_1/c/add", c_1_c.makeOpName("add")); - - Scope c_1_c_1 = c_1.withSubScope("c"); - assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add")); - } - } - - @Test - public void scopeAndOpNames() { - try (Graph g = new Graph()) { - Scope root = new Scope(g); - - Scope child = root.withSubScope("child"); - - assertEquals("child/add", child.makeOpName("add")); - assertEquals("child_1", root.makeOpName("child")); - assertEquals("child_2/p", root.withSubScope("child").makeOpName("p")); - } - } - - @Test - public void validateNames() { - try (Graph g = new Graph()) { - Scope root = new Scope(g); - - final String[] invalid_names = { - "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] - null, "", "a$", // Invalid characters - "a/b", // slashes not allowed - }; - - for (String name : invalid_names) { - try { - root.withName(name); - fail("failed to catch invalid op name."); - } catch (IllegalArgumentException ex) { - // expected - } - // Subscopes follow the same rules - try { - root.withSubScope(name); - fail("failed to catch invalid scope name: " + name); - } catch (IllegalArgumentException ex) { - // expected - } - } - - // Unusual but valid names. - final String[] valid_names = {".", "..", "._-.", "a--."}; - - for (String name : valid_names) { - root.withName(name); - root.withSubScope(name); - } - } - } - - @Test - public void basic() { - try (Graph g = new Graph()) { - Scope s = new Scope(g); - Const c1 = Const.create(s, 42); - assertEquals("Const", c1.output().op().name()); - Const c2 = Const.create(s, 7); - assertEquals("Const_1", c2.output().op().name()); - Const c3 = Const.create(s.withName("four"), 4); - assertEquals("four", c3.output().op().name()); - Const c4 = Const.create(s.withName("four"), 4); - assertEquals("four_1", c4.output().op().name()); - } - } - - @Test - public void hierarchy() { - try (Graph g = new Graph()) { - Scope root = new Scope(g); - Scope child = root.withSubScope("child"); - assertEquals("child/Const", Const.create(child, 42).output().op().name()); - assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name()); - } - } - - @Test - public void composite() { - try (Graph g = new Graph(); - Session sess = new Session(g)) { - Scope s = new Scope(g); - Output data = Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output(); - - // Create a composite op with a customized name - Variance var1 = Variance.create(s.withName("example"), data); - assertEquals("example/variance", var1.output().op().name()); - - // Confirm internally added ops have the right names. - assertNotNull(g.operation("example/squared_deviation")); - assertNotNull(g.operation("example/Mean")); - assertNotNull(g.operation("example/zero")); - - // Same composite op with a default name - Variance var2 = Variance.create(s, data); - assertEquals("variance/variance", var2.output().op().name()); - - // Confirm internally added ops have the right names. - assertNotNull(g.operation("variance/squared_deviation")); - assertNotNull(g.operation("variance/Mean")); - assertNotNull(g.operation("variance/zero")); - - // Verify correct results as well. - Tensor result = sess.runner().fetch(var1.output()).run().get(0); - assertEquals(21704, result.intValue()); - result = sess.runner().fetch(var2.output()).run().get(0); - assertEquals(21704, result.intValue()); - } - } - - // "handwritten" sample operator classes - private static final class Const { - private final Output output; - - static Const create(Scope s, Object v) { - try (Tensor value = Tensor.create(v)) { - return new Const( - s.graph() - .opBuilder("Const", s.makeOpName("Const")) - .setAttr("dtype", value.dataType()) - .setAttr("value", value) - .build() - .output(0)); - } - } - - Const(Output o) { - output = o; - } - - Output output() { - return output; - } - } - - private static final class Mean { - private final Output output; - - static Mean create(Scope s, Output input, Output reductionIndices) { - return new Mean( - s.graph() - .opBuilder("Mean", s.makeOpName("Mean")) - .addInput(input) - .addInput(reductionIndices) - .build() - .output(0)); - } - - Mean(Output o) { - output = o; - } - - Output output() { - return output; - } - } - - private static final class SquaredDifference { - private final Output output; - - static SquaredDifference create(Scope s, Output x, Output y) { - return new SquaredDifference( - s.graph() - .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference")) - .addInput(x) - .addInput(y) - .build() - .output(0)); - } - - SquaredDifference(Output o) { - output = o; - } - - Output output() { - return output; - } - } - - private static final class Variance { - private final Output output; - - static Variance create(Scope base, Output x) { - Scope s = base.withSubScope("variance"); - Output zero = Const.create(s.withName("zero"), new int[] {0}).output(); - Output sqdiff = - SquaredDifference.create( - s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) - .output(); - - return new Variance(Mean.create(s.withName("variance"), sqdiff, zero).output()); - } - - Variance(Output o) { - output = o; - } - - Output output() { - return output; - } - } -} |