
I have a class representing a model that is set up as follows:

class Model:
  def __init__(self):
    self.setup_graph()

  def setup_graph(self):
    # sets up the model
    ...

  def train(self, dataset):
    # dataset is a tf.data.Dataset iterator, from which I can get
    # tf.Tensor objects directly, which become part of the graph
    ...

  def predict(self, sample):
    # sample is a single NumPy array representing a sample,
    # which could be fed to a tf.placeholder using feed_dict
    ...

During training I want to make use of the efficiency of TensorFlow's tf.data.Dataset, but I still want to be able to get the model's output for a single sample. It seems to me that this requires recreating the graph for prediction. Is this true, or can I build a single TF graph that I can run either with a batch from a tf.data.Dataset or with a single sample fed to a tf.placeholder?

1 Answer


You can create your model with a dataset, iterator, etc. as usual. Then, if you want to pass some custom data with feed_dict, you can do so by feeding values directly into the tensors produced by get_next():

import tensorflow as tf
import numpy as np

# A toy dataset: 100 rows of ones, in batches of 5.
dataset = (tf.data.Dataset
    .from_tensor_slices(np.ones((100, 3), dtype=np.float32))
    .batch(5))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

# Stand-in for the model: any graph built on top of `batch`.
output = 2 * batch

with tf.Session() as sess:
    print('From iterator:')
    print(sess.run(output))
    print('From feed_dict:')
    # Feeding `batch` directly overrides the iterator for this run.
    print(sess.run(output, feed_dict={batch: [[1, 2, 3]]}))

Output:

From iterator:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
From feed_dict:
[[2. 4. 6.]]
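
Applied to the Model class from your question, a sketch could look like the following (this assumes the dataset is available when the graph is built, and the 2 * batch op stands in for the real model):

import tensorflow as tf
import numpy as np

class Model:
    def __init__(self, dataset):
        self.setup_graph(dataset)

    def setup_graph(self, dataset):
        # Build the graph once; the iterator output doubles as the input tensor.
        iterator = dataset.make_one_shot_iterator()
        self.batch = iterator.get_next()
        self.output = 2 * self.batch  # stand-in for the real model

    def train(self, sess):
        # Pulls batches from the dataset through the iterator.
        return sess.run(self.output)

    def predict(self, sess, sample):
        # Overrides the iterator output with a single NumPy sample.
        return sess.run(self.output, feed_dict={self.batch: sample[None]})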

In principle you could achieve the same effect with an initializable, reinitializable or feedable iterator, but if you really just want to test single samples of data, I think this is the quickest and least intrusive way.
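
For reference, a minimal sketch of the feedable-iterator variant would look roughly like this (the toy datasets and the 2 * batch op are just illustrations):

import tensorflow as tf
import numpy as np

train_dataset = (tf.data.Dataset
    .from_tensor_slices(np.ones((100, 3), dtype=np.float32))
    .batch(5))
predict_dataset = (tf.data.Dataset
    .from_tensor_slices(np.array([[1, 2, 3]], dtype=np.float32))
    .batch(1))

# A string handle placeholder selects which iterator feeds the graph.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
batch = iterator.get_next()
output = 2 * batch

train_iterator = train_dataset.make_one_shot_iterator()
predict_iterator = predict_dataset.make_one_shot_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    predict_handle = sess.run(predict_iterator.string_handle())
    print(sess.run(output, feed_dict={handle: train_handle}))
    print(sess.run(output, feed_dict={handle: predict_handle}))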


3 Comments

Great, jdehesa, thank you! Is there a way I can do this the other way around, i.e. can I set it up so that I can pass self.train() a tf.data.Dataset object?
@MoosHueting You cannot pass a dataset (or an iterator, or its output) in the feed_dict. You can create different datasets and switch between them with a reinitializable or a feedable iterator (see Creating an iterator), or you can switch between a dataset and a placeholder or something else with tf.cond, or something like that. I am not sure which would be the best option for your use case.
Thanks @jdehesa. Feedable iterators turned out to be the solution for me.
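
For completeness, the reinitializable-iterator option mentioned in the comments could be sketched as follows (again with toy datasets and the 2 * batch op standing in for the real model):

import tensorflow as tf
import numpy as np

train_dataset = (tf.data.Dataset
    .from_tensor_slices(np.ones((100, 3), dtype=np.float32))
    .batch(5))
other_dataset = (tf.data.Dataset
    .from_tensor_slices(np.zeros((10, 3), dtype=np.float32))
    .batch(5))

# One iterator defined only by structure, reinitialized per dataset.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
batch = iterator.get_next()
output = 2 * batch

train_init = iterator.make_initializer(train_dataset)
other_init = iterator.make_initializer(other_dataset)

with tf.Session() as sess:
    sess.run(train_init)
    print(sess.run(output))  # batches from train_dataset
    sess.run(other_init)
    print(sess.run(output))  # batches from other_dataset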
