diff --git a/tensorflow_probability/python/mcmc/diagnostic.py b/tensorflow_probability/python/mcmc/diagnostic.py index dd09b82020..0416e3daa1 100644 --- a/tensorflow_probability/python/mcmc/diagnostic.py +++ b/tensorflow_probability/python/mcmc/diagnostic.py @@ -160,13 +160,14 @@ def effective_sample_size(states, target_log_prob_fn=target.log_prob, step_size=0.05, num_leapfrog_steps=20)) - states.shape + print(states.shape) ==> (1000, 2) ess = effective_sample_size(states, filter_beyond_positive_pairs=True) - ==> Shape (2,) Tensor + print(ess.shape) + ==> (2,) - mean, variance = tf.nn.moments(states, axis=0) + mean, variance = tf.nn.moments(states, axes=0) standard_error = tf.sqrt(variance / ess) ```