aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Session.java
blob: 0d512978461dbb680c23ab0400a25a4d070ce0fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.util.ArrayList;
import java.util.List;

/**
 * Driver for {@link Graph} execution.
 *
 * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
 * {@link Graph} are executed to compute {@link Tensor}s. For example:
 *
 * <pre>{@code
 * // Let's say graph is an instance of the Graph class
 * // for the computation y = 3 * x
 *
 * try (Session s = new Session(graph)) {
 *   try (Tensor x = Tensor.create(2.0f);
 *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
 *       System.out.println(y.floatValue());  // Will print 6.0f
 *   }
 *   try (Tensor x = Tensor.create(1.1f);
 *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
 *       System.out.println(y.floatValue());  // Will print 3.3f
 *   }
 * }
 * }</pre>
 *
 * <p><b>WARNING:</b> A {@code Session} owns resources that <b>must</b> be explicitly freed by
 * invoking {@link #close()}.
 *
 * <p>Instances of a Session are thread-safe.
 */
public final class Session implements AutoCloseable {

  /**
   * Creates a session that executes operations of the given {@link Graph}.
   *
   * <p>A temporary reference on the graph is held while the native session is allocated from the
   * graph's native handle. A second reference is then acquired and kept for the lifetime of the
   * session. The temporary reference is released before the constructor returns, whether or not
   * allocation succeeded.
   *
   * @param g the graph whose operations this session will run
   */
  public Session(Graph g) {
    this.graph = g;
    final Graph.Reference allocationRef = g.ref();
    try {
      this.nativeHandle = allocate(allocationRef.nativeHandle());
      this.graphRef = g.ref();
    } finally {
      allocationRef.close();
    }
  }

  /**
   * Wraps an already existing native session that is associated with the given {@link Graph}.
   *
   * <p>A reference on the graph is acquired and kept for the lifetime of the session.
   *
   * @param g the graph associated with the native session
   * @param nativeHandle handle of the existing native session; stored as-is
   */
  Session(Graph g, long nativeHandle) {
    this.graph = g;
    this.graphRef = g.ref();
    this.nativeHandle = nativeHandle;
  }

  /**
   * Release resources associated with the Session.
   *
   * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls), then
   * deletes the native session and, last of all, releases the reference this Session holds on its
   * {@link Graph}. A Session is not usable after close returns. Calling close again after a
   * successful close does not touch the native session a second time.
   *
   * <p>If the calling thread is interrupted while waiting for active executions, its interrupt
   * status is restored and this method returns early <em>without</em> deleting the native session
   * and <em>without</em> releasing the graph reference. In-flight runs therefore keep working
   * against a graph this Session still references, and close can be invoked again later to finish
   * the cleanup.
   */
  @Override
  public void close() {
    synchronized (nativeHandleLock) {
      if (nativeHandle != 0) {
        while (numActiveRuns > 0) {
          try {
            nativeHandleLock.wait();
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // The native session is still alive and may be executing; keep the graph reference so
            // the graph cannot be released underneath it. A later close() call can complete the
            // cleanup.
            return;
          }
        }
        delete(nativeHandle);
        nativeHandle = 0;
      }
    }
    // Released only after the native session is gone, and outside nativeHandleLock so that no
    // foreign code runs while the session lock is held. As before, this is invoked on every
    // successful close(), so it relies on Graph.Reference.close() tolerating repeated calls.
    // NOTE(review): presumably Graph.close() waits for outstanding references - confirm in Graph.
    graphRef.close();
  }

  /**
   * Run {@link Operation}s and evaluate {@link Tensor}s.
   *
   * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
   * evaluate the {@link Tensor}s to fetch. The {@link #feed(String,int,Tensor)} call allows callers
   * to override the value of {@link Tensor}s in the graph by substituting the provided {@link
   * Tensor}s for the outputs of the operations provided to {@link #feed(String,int,Tensor)}.
   */
  public final class Runner {
    /**
     * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces.
     *
     * <p>This method is a shorthand for {@code feed(operation, 0, t)}.
     *
     * @param operation name of the operation in the graph whose first output is overridden
     * @param t the tensor to use in place of that output
     * @return this Runner, to allow call chaining
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}
     * @throws NullPointerException if {@code t} is null
     */
    public Runner feed(String operation, Tensor t) {
      return feed(operation, 0, t);
    }

    /**
     * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t}
     * for the value it produces.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one {@code t} is being provided for.
     *
     * @param operation name of the operation in the graph
     * @param index which output of {@code operation} is overridden
     * @param t the tensor to use in place of that output
     * @return this Runner, to allow call chaining
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}
     * @throws NullPointerException if {@code t} is null
     */
    public Runner feed(String operation, int index, Tensor t) {
      // operationByName throws for unknown names, so no null check is needed on its result.
      Operation op = operationByName(operation);
      if (t == null) {
        // Fail at the call site instead of with an anonymous NullPointerException inside run().
        throw new NullPointerException(
            "Cannot feed a null Tensor for Operation [" + operation + "]");
      }
      // The two lists are parallel: entry i of inputs is overridden by entry i of inputTensors.
      inputs.add(op.output(index));
      inputTensors.add(t);
      return this;
    }

    /**
     * Make {@link #run()} return the output of {@code operation}.
     *
     * <p>This method is a shorthand for {@code fetch(operation, 0)}
     *
     * @param operation name of the operation in the graph whose first output is fetched
     * @return this Runner, to allow call chaining
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}
     */
    public Runner fetch(String operation) {
      return fetch(operation, 0);
    }

    /**
     * Make {@link #run()} return the {@code index}-th output of {@code operation}.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one to return.
     *
     * @param operation name of the operation in the graph
     * @param index which output of {@code operation} to fetch
     * @return this Runner, to allow call chaining
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}
     */
    public Runner fetch(String operation, int index) {
      // operationByName throws for unknown names, so its result is never null here.
      outputs.add(operationByName(operation).output(index));
      return this;
    }

    /**
     * Make {@link #run()} execute {@code operation}, but not return the evaluated {@link Tensor}.
     *
     * @param operation name of the operation in the graph to execute
     * @return this Runner, to allow call chaining
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}
     */
    public Runner addTarget(String operation) {
      // operationByName throws for unknown names, so its result is never null here.
      targets.add(operationByName(operation));
      return this;
    }

    /**
     * (Experimental method): set options (typically for debugging) for this run.
     *
     * <p>{@code options} is a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
     * protocol buffer</a>. The array is stored by reference, not copied, and is handed to the
     * native run call as-is.
     *
     * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain
     * friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). The price is the lack of type-safety in this API function. This choice
     * is under review and this function may be replaced by more type-safe equivalents at any time.
     *
     * @param options serialized RunOptions bytes; replaces any previously set options
     * @return this Runner, to allow call chaining
     */
    public Runner setOptions(byte[] options) {
      runOptions = options;
      return this;
    }

    /**
     * Execute the graph fragments necessary to compute all requested fetches.
     *
     * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor}s, i.e., the
     * caller must call {@link Tensor#close()} on all elements of the returned list to free up
     * resources.
     *
     * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
     * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
     * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
     * {@code Map<Output, Tensor>}?
     *
     * @return the fetched tensors; run metadata is not requested from the native call
     */
    public List<Tensor> run() {
      final Run result = runHelper(false);
      return result.outputs;
    }

    /**
     * Execute graph fragments to compute requested fetches and return metadata about the run.
     *
     * <p>Behaves like {@link #run()}, except that metadata about the graph execution is requested
     * as well and returned alongside the tensors, in the form of a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>.
     *
     * @return the fetched tensors together with the metadata bytes
     */
    public Run runAndFetchMetadata() {
      final boolean wantMetadata = true;
      return runHelper(wantMetadata);
    }

    /**
     * Shared implementation behind {@link #run()} and {@link #runAndFetchMetadata()}.
     *
     * <p>Flattens the recorded feeds, fetches and targets into primitive arrays of native handles
     * and output indices, invokes the native {@code Session.run} while registered as an active
     * run, and finally wraps every returned tensor handle in a {@link Tensor}.
     *
     * @param wantMetadata passed to the native call as its {@code wantRunMetadata} argument
     * @return a {@link Run} carrying the wrapped output tensors and whatever the native call
     *     returned as metadata
     * @throws IllegalStateException if the Session has already been closed
     */
    private Run runHelper(boolean wantMetadata) {
      // Tensors supplied through feed().
      final long[] feedTensorHandles = new long[inputTensors.size()];
      for (int i = 0; i < feedTensorHandles.length; i++) {
        feedTensorHandles[i] = inputTensors.get(i).getNativeHandle();
      }

      // Operation.getUnsafeNativeHandle() is used below on the premise that the Graph stays valid
      // for as long as this Session holds graphRef.

      // Graph outputs that the fed tensors stand in for.
      final long[] feedOpHandles = new long[inputs.size()];
      final int[] feedOpIndices = new int[feedOpHandles.length];
      for (int i = 0; i < feedOpHandles.length; i++) {
        final Output fed = inputs.get(i);
        feedOpHandles[i] = fed.op().getUnsafeNativeHandle();
        feedOpIndices[i] = fed.index();
      }

      // Graph outputs requested through fetch().
      final long[] fetchOpHandles = new long[outputs.size()];
      final int[] fetchOpIndices = new int[fetchOpHandles.length];
      for (int i = 0; i < fetchOpHandles.length; i++) {
        final Output wanted = outputs.get(i);
        fetchOpHandles[i] = wanted.op().getUnsafeNativeHandle();
        fetchOpIndices[i] = wanted.index();
      }

      // Operations that are executed without returning a value.
      final long[] targetOpHandles = new long[targets.size()];
      for (int i = 0; i < targetOpHandles.length; i++) {
        targetOpHandles[i] = targets.get(i).getUnsafeNativeHandle();
      }

      // Filled in by the native call, one handle per fetch.
      final long[] fetchedTensorHandles = new long[fetchOpHandles.length];

      // Registering the run makes Session.close() wait until the native call has returned.
      final Reference activeRun = new Reference();
      final byte[] metadata;
      try {
        metadata =
            Session.run(
                nativeHandle,
                runOptions,
                feedTensorHandles,
                feedOpHandles,
                feedOpIndices,
                fetchOpHandles,
                fetchOpIndices,
                targetOpHandles,
                wantMetadata,
                fetchedTensorHandles);
      } finally {
        activeRun.close();
      }

      final List<Tensor> fetched = new ArrayList<Tensor>();
      for (int i = 0; i < fetchedTensorHandles.length; i++) {
        try {
          fetched.add(Tensor.fromHandle(fetchedTensorHandles[i]));
        } catch (Exception e) {
          // Close the tensors wrapped so far before propagating the failure.
          for (Tensor wrapped : fetched) {
            wrapped.close();
          }
          fetched.clear();
          throw e;
        }
      }

      final Run result = new Run();
      result.outputs = fetched;
      result.metadata = metadata;
      return result;
    }

    /**
     * Marks one in-flight execution of the enclosing Session.
     *
     * <p>Construction counts the execution in {@code numActiveRuns}; {@link #close()} removes it
     * again and wakes up threads waiting on {@code nativeHandleLock} once the count reaches zero.
     * Both steps run while holding {@code nativeHandleLock}.
     */
    private class Reference implements AutoCloseable {
      /**
       * Registers a new active run.
       *
       * @throws IllegalStateException if the Session's native handle is already 0, i.e. the
       *     Session has been closed
       */
      public Reference() {
        synchronized (nativeHandleLock) {
          if (nativeHandle == 0) {
            throw new IllegalStateException("run() cannot be called on the Session after close()");
          }
          numActiveRuns += 1;
        }
      }

      /** Unregisters the run; does nothing when the Session's native handle is already 0. */
      @Override
      public void close() {
        synchronized (nativeHandleLock) {
          if (nativeHandle != 0) {
            numActiveRuns -= 1;
            if (numActiveRuns == 0) {
              nativeHandleLock.notifyAll();
            }
          }
        }
      }
    }

    /**
     * Looks up an operation of the Session's graph by name.
     *
     * @param opName name of the operation to find
     * @return the matching operation, never null
     * @throws IllegalArgumentException if the graph reports no operation with that name
     */
    private Operation operationByName(String opName) {
      final Operation found = graph.operation(opName);
      if (found != null) {
        return found;
      }
      throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph");
    }

    // Feeds: inputs and inputTensors are parallel lists, appended to together by feed(); entry i
    // of inputTensors is substituted for the graph output at entry i of inputs.
    private final List<Output> inputs = new ArrayList<Output>();
    private final List<Tensor> inputTensors = new ArrayList<Tensor>();
    // Graph outputs whose values the run returns.
    private final List<Output> outputs = new ArrayList<Output>();
    // Operations that are executed without returning a value.
    private final List<Operation> targets = new ArrayList<Operation>();
    // Serialized RunOptions passed to the native run call; null until setOptions() is called.
    private byte[] runOptions;
  }

  /**
   * Create a Runner to execute graph operations and evaluate Tensors.
   *
   * @return a new, empty {@link Runner} bound to this Session
   */
  public Runner runner() {
    return this.new Runner();
  }

  /**
   * Output tensors and metadata obtained when executing a session.
   *
   * <p>Instances are created and filled in by the {@link Runner}; both fields are plain public
   * mutable fields.
   *
   * <p>See {@link Runner#runAndFetchMetadata()}
   */
  public static final class Run {
    /**
     * Tensors from requested fetches, in the order of the tensor handles filled in by the native
     * run call. As documented on {@link Runner#run()}, the caller owns these tensors and must
     * close them.
     */
    public List<Tensor> outputs;

    /**
     * (Experimental): Metadata about the run.
     *
     * <p>A serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies
     * in order to remain friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). A cost of that is this opaque blob. This choice is under review and
     * this field may be replaced by more type-safe equivalents at any time.
     */
    public byte[] metadata;
  }

  // Graph this session executes; Runner resolves operation names against it.
  private final Graph graph;
  // Reference on the graph acquired in the constructors and released by close().
  private final Graph.Reference graphRef;

  // Monitor under which close() and Runner.Reference read and update nativeHandle and
  // numActiveRuns; close() waits on it and a finishing run notifies it.
  private final Object nativeHandleLock = new Object();
  // Handle of the native session; reset to 0 by close() after the native session is deleted.
  private long nativeHandle;
  // Number of Runner executions currently registered as active.
  private int numActiveRuns;

  /**
   * Native (JNI) method. The public constructor calls it with the graph's native handle and stores
   * the result as this session's native handle. The implementation is not part of this file.
   */
  private static native long allocate(long graphHandle);

  /**
   * Native (JNI) method. {@link #close()} calls it with this session's native handle once no runs
   * are active, and resets the handle to 0 afterwards. The implementation is not part of this
   * file.
   */
  private static native void delete(long handle);

  /**
   * Execute a session.
   *
   * <p>The author apologizes for the ugliness of the long argument list of this method. However,
   * take solace in the fact that this is a private method meant to cross the JNI boundary.
   *
   * @param handle handle to the C API TF_Session object (Session.nativeHandle)
   * @param runOptions serialized representation of a RunOptions protocol buffer, or null
   * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
   *     that are being "fed" (do not need to be computed) during graph execution.
   *     inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the
   *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
   *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
   * @param inputOpHandles (see inputTensorHandles)
   * @param inputOpIndices (see inputTensorHandles)
   * @param outputOpHandles (see outputOpIndices)
   * @param outputOpIndices together with outputOpHandles identifies the set of values that should
   *     be computed: the outputOpIndices[i]-th output of the Operation outputOpHandles[i]. It is
   *     required that outputOpHandles.length == outputOpIndices.length.
   * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
   *     output will not be returned
   * @param wantRunMetadata indicates whether metadata about this execution should be returned.
   * @param outputTensorHandles will be filled in with handles to the outputs requested. It is
   *     required that outputTensorHandles.length == outputOpHandles.length.
   * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
   *     buffer; otherwise presumably null (an earlier version of this comment said "false", which
   *     a byte[] cannot be - TODO confirm against the JNI implementation).
   */
  private static native byte[] run(
      long handle,
      byte[] runOptions,
      long[] inputTensorHandles,
      long[] inputOpHandles,
      int[] inputOpIndices,
      long[] outputOpHandles,
      int[] outputOpIndices,
      long[] targetOpHandles,
      boolean wantRunMetadata,
      long[] outputTensorHandles);
}