From 9cdc4ee4a52f61ffa258b97a8432e74dd5abc213 Mon Sep 17 00:00:00 2001 From: RenuPatelGoogle <89264621+RenuPatelGoogle@users.noreply.github.com> Date: Mon, 6 Jun 2022 19:33:05 +0530 Subject: [PATCH] Fixed few code lines in this API example Updated `axis` to `axes` as per tf.nn.moments() actual syntax and added full alias name (tfp.mcmc.) to effective_sample_size() to execute the example successfully. I have replicated and fixed the code error of this API example in this [gist](https://colab.sandbox.google.com/gist/RenuPatelGoogle/9a528802d6e52e46ad9713c6391bbc5c/tfp-mcmc-effective_sample_size.ipynb#scrollTo=pW_kQ2Px3VLs) for your reference. --- tensorflow_probability/python/mcmc/diagnostic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/mcmc/diagnostic.py b/tensorflow_probability/python/mcmc/diagnostic.py index dd09b82020..0416e3daa1 100644 --- a/tensorflow_probability/python/mcmc/diagnostic.py +++ b/tensorflow_probability/python/mcmc/diagnostic.py @@ -160,13 +160,14 @@ def effective_sample_size(states, target_log_prob_fn=target.log_prob, step_size=0.05, num_leapfrog_steps=20)) - states.shape + print(states.shape) ==> (1000, 2) ess = effective_sample_size(states, filter_beyond_positive_pairs=True) - ==> Shape (2,) Tensor + print(ess.shape) + ==> (2,) - mean, variance = tf.nn.moments(states, axis=0) + mean, variance = tf.nn.moments(states, axes=0) standard_error = tf.sqrt(variance / ess) ```