Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Example of loading a tensorflow model from scratch #223

Open
mencia opened this issue Jan 21, 2020 · 0 comments
Open

Example of loading a tensorflow model from scratch #223

mencia opened this issue Jan 21, 2020 · 0 comments

Comments

@mencia
Copy link

mencia commented Jan 21, 2020

Back in 2018 there was a discussion on how to load your own tensorflow model #34. Later a new way of doing it was suggested #152.

It would be very helpful if there was a minimal example where: 1) a model is built, 2) trained, 3) saved and 4) visualized. I figured out the first three steps, but I am stuck on the fourth. Below I will show the first three steps:

1. Build a VAE model

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tfd = tf.contrib.distributions

class VAE:

    """VAE"""

    def __init__(self, data):

        self.data = data
        self.loss = self.build_loss()
        self.sample = self.sample()

    def make_encoder(self, data, code_size):
      x = tf.layers.flatten(data)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      loc = tf.layers.dense(x, code_size)
      scale = tf.layers.dense(x, code_size, tf.nn.softplus)
      return tfd.MultivariateNormalDiag(loc, scale)

    def make_prior(self, code_size):
      loc = tf.zeros(code_size)
      scale = tf.ones(code_size)
      return tfd.MultivariateNormalDiag(loc, scale)

    def make_decoder(self, code, data_shape):
      x = code
      x = tf.layers.dense(x, 200, tf.nn.relu)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      logit = tf.layers.dense(x, np.prod(data_shape))
      logit = tf.reshape(logit, [-1] + data_shape)
      return tfd.Independent(tfd.Bernoulli(logit), 2)

    def build_loss(self):
        """We sample the posterior to input the decoder"""
        prior = self.make_prior(code_size=2)
        posterior = self.make_encoder(self.data, code_size=2)
        code = posterior.sample()
        likelihood = self.make_decoder(code, [28, 28]).log_prob(self.data)
        divergence = tfd.kl_divergence(posterior, prior)
        elbo = tf.reduce_mean(likelihood - divergence)
        return -elbo

    def sample(self):
        """Decodes a random code"""
        prior = self.make_prior(code_size=2)
        return self.make_decoder(prior.sample(10), [28, 28]).mean()

2. Train the model

mnist = input_data.read_data_sets('MNIST_data/')
data = tf.placeholder(tf.float32, [None, 28, 28])
model = VAE(data)
loss = model.loss
optimize = tf.train.AdamOptimizer(0.001).minimize(loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(2):
        for _ in range(60):
            feed_dict = {data: mnist.train.next_batch(100)[0].reshape([-1, 28, 28])}
            sess.run(optimize, feed_dict)
            saver.save(sess, './logging/model_final')

3. Load the trained model and save it for lucid visualization

from lucid.modelzoo.vision_models import Model

with tf.Graph().as_default() as graph, tf.Session() as sess:
    
    path = './logging/'
    ckpt_state = tf.train.get_checkpoint_state(path)
    data = tf.placeholder(tf.float32, [None, 28, 28], name='images')
    model = VAE(data)
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_state.model_checkpoint_path)
    
    Model.save("saved_model.pb",  
     input_name='images', 
     output_names=[graph.as_graph_def().node[-1].name], 
     image_shape=[28,28],
     image_value_range=[0,1])

4. Visualize

I get an error when trying to visualize it.

from lucid.modelzoo.vision_models import Model
import lucid.optvis.render as render

model = Model.load("saved_model.pb")
_ = render.render_vis(model, "dense_9/kernel:0")

The raised error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/env_py36/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    426         results = c_api.TF_GraphImportGraphDefWithResults(
--> 427             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    428         results = c_api_util.ScopedTFImportGraphDefResults(results)

InvalidArgumentError: Input 0 of node import/save/Assign was passed float from import/dense/bias:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-6-dc71ccdc8e67> in <module>
      1 import lucid.optvis.render as render
----> 2 _ = render.render_vis(model, "dense_9/kernel:0")

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in render_vis(model, objective_f, param_f, optimizer, transforms, thresholds, print_objectives, verbose, relu_gradient_override, use_fixed_seed)
     93 
     94     T = make_vis_T(model, objective_f, param_f, optimizer, transforms,
---> 95                    relu_gradient_override)
     96     print_objective_func = make_print_objective_func(print_objectives, T)
     97     loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in make_vis_T(model, objective_f, param_f, optimizer, transforms, relu_gradient_override)
    175     with gradient_override_map({'Relu': redirected_relu_grad,
    176                                 'Relu6': redirected_relu6_grad}):
--> 177       T = import_model(model, transform_f(t_image), t_image)
    178   else:
    179     T = import_model(model, transform_f(t_image), t_image)

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in import_model(model, t_image, t_image_raw, scope, input_map)
    255     t_image_raw = t_image
    256 
--> 257   model.import_graph(t_image, scope=scope, forget_xy_shape=True, input_map=input_map)
    258 
    259   def T(layer):

~/env_py36/lib/python3.6/site-packages/lucid/modelzoo/vision_base.py in import_graph(self, t_input, scope, forget_xy_shape, input_map)
    198       final_input_map.update(input_map)
    199     tf.import_graph_def(
--> 200         self.graph_def, final_input_map, name=scope)
    201     self.post_import(scope)
    202 

~/env_py36/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/env_py36/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    429       except errors.InvalidArgumentError as e:
    430         # Convert to ValueError for backwards compatibility.
--> 431         raise ValueError(str(e))
    432 
    433     # Create _DefinedFunctions for any imported functions.

ValueError: Input 0 of node import/save/Assign was passed float from import/dense/bias:0 incompatible with expected float_ref.

Could someone provide a working minimal example please?

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant