Merge pull request scikit-learn#2 from GaelVaroquaux/hmmc
Hmmc
lucidfrontier45 committed Mar 5, 2012
2 parents 88f35b4 + bd2405f commit e361207
Showing 4 changed files with 93 additions and 38 deletions.
12 changes: 6 additions & 6 deletions doc/modules/hmm.rst
@@ -15,7 +15,7 @@ the first order Markov Chain. It can be specified by the start probability
vector :math:`\boldsymbol{\Pi}` and the transition probability matrix
:math:`\mathbf{A}`.
The emission probability of an observable can be any distribution with the
parameters :math:`\boldsymbol{\Theta}_i}` conditioned on the current hidden
parameters :math:`\boldsymbol{{\Theta}_i}` conditioned on the current hidden
state index. (e.g. Multinomial, Gaussian).
Thus the HMM can be completely determined by
:math:`\boldsymbol{\Pi, \mathbf{A}}` and :math:`\boldsymbol{{\Theta}_i}`.
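As a minimal sketch of this parameterisation in numpy (the 2-state values below are illustrative, not taken from the library)::

    >>> import numpy as np
    >>> startprob = np.array([0.8, 0.2])       # Pi: initial state distribution
    >>> transmat = np.array([[0.9, 0.1],
    ...                      [0.4, 0.6]])      # A: transition probability matrix
    >>> means = np.array([[0.0], [3.0]])       # Theta: per-state Gaussian means
    >>> covars = np.tile(np.identity(1), (2, 1, 1))  # Theta: per-state covariances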
@@ -42,8 +42,7 @@ See the ref listed below for further detailed information.

.. topic:: References:

.. [Rabiner89] `"A tutorial on hidden Markov models and selected applications in speech recognition"
<http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf>`_
[Rabiner89] `A tutorial on hidden Markov models and selected applications in speech recognition <http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf>`_
Lawrence, R. Rabiner, 1989


@@ -100,9 +99,10 @@ This time, the input is a single sequence of observed values.::

>>> model2 = hmm.GaussianHMM(3, "full")
>>> model2.fit([X])
GaussianHMM(covariance_type='full', covars_prior=0.01, covars_weight=1,
means_prior=None, means_weight=0, n_components=3, startprob=None,
startprob_prior=1.0, transmat=None, transmat_prior=1.0)
GaussianHMM(algorithm='viterbi', covariance_type='full', covars_prior=0.01,
covars_weight=1, means_prior=None, means_weight=0, n_components=3,
random_state=None, startprob=None, startprob_prior=1.0,
transmat=None, transmat_prior=1.0)
>>> Z2 = model2.predict(X)
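Once fitted, the same estimator can also score a sequence and recover the most likely state path. A hedged sketch using the `model2` from above (`score` and `decode` belong to the same `hmm` API; the exact numbers depend on the sampled data)::

    >>> logprob = model2.score(X)           # log-likelihood of the observed sequence
    >>> logprob, states = model2.decode(X)  # Viterbi log-probability and state path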


59 changes: 43 additions & 16 deletions examples/plot_hmm_sampling.py
@@ -3,32 +3,59 @@
Demonstration of sampling from HMM
==================================
This script shows how to sample points from HMM.
This script shows how to sample points from a Hidden Markov Model (HMM):
we use a 4-component HMM with specified means and covariances.
The plot shows the sequence of observations generated, with the transitions
between them. We can see that, as specified by our transition matrix,
there are no transitions between components 1 and 3.
"""

import numpy as np
from sklearn import hmm
import matplotlib.pyplot as plt

from sklearn import hmm

##############################################################
# prepareing parameters
startprob = np.array([0.6, 0.3, 0.1])
transmat = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2],
[0.2, 0.2, 0.6]])
means = np.array([[0.0, 0.0], [5.0, -1.0], [5.0, 10.0]])
covars = np.tile(np.identity(2), (3, 1, 1))

# build an HMM instance and set parameters
model = hmm.GaussianHMM(3, "full", startprob, transmat)
# Prepare parameters for a 4-component HMM
# Initial population probability
start_prob = np.array([0.6, 0.3, 0.1, 0.0])
# The transition matrix, note that there are no transitions possible
# between components 1 and 3 (or between 2 and 4)
trans_mat = np.array([[0.7, 0.2, 0.0, 0.1],
[0.3, 0.5, 0.2, 0.0],
[0.0, 0.3, 0.5, 0.2],
[0.2, 0.0, 0.2, 0.6]])
# The means of each component
means = np.array([[0.0, 0.0],
[0.0, 11.0],
[9.0, 10.0],
[11.0, -1.0],
])
# The covariance of each component
covars = 0.5 * np.tile(np.identity(2), (4, 1, 1))

# Build an HMM instance and set parameters
model = hmm.GaussianHMM(4, "full", start_prob, trans_mat,
random_state=42)

# Instead of fitting it from the data, we directly set the estimated
# parameters, the means and covariance of the components
model.means_ = means
model.covars_ = covars
###############################################################

# generate samples
# Generate samples
X, Z = model.sample(500)

#plot the sampled data
plt.plot(X[:, 0], X[:, 1], "-o", label="observable", ms=10,
mfc="orange", alpha=0.7)
plt.legend()
# Plot the sampled data
plt.plot(X[:, 0], X[:, 1], "-o", label="observations", ms=6,
mfc="orange", alpha=0.7)

# Indicate the component numbers
for i, m in enumerate(means):
plt.text(m[0], m[1], 'Component %i' % (i + 1),
size=17, horizontalalignment='center',
bbox=dict(alpha=.7, facecolor='w'))
plt.legend(loc='best')
plt.show()
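As a quick consistency check on the sampled sequence, the forbidden transitions can be counted directly; a small sketch assuming the `Z` returned by `model.sample(500)` above (state indices 0 and 2 correspond to components 1 and 3):

n_13 = np.sum((Z[:-1] == 0) & (Z[1:] == 2))
n_31 = np.sum((Z[:-1] == 2) & (Z[1:] == 0))
assert n_13 == 0 and n_31 == 0  # both transition probabilities are zero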
56 changes: 42 additions & 14 deletions sklearn/hmm.py
@@ -90,6 +90,9 @@ class _BaseHMM(BaseEstimator):
algorithm : string, one of the decoder_algorithms
decoder algorithm
random_state : RandomState or an int seed (None by default)
A random number generator instance
See Also
--------
GMM : Gaussian mixture model
@@ -107,7 +110,8 @@ class _BaseHMM(BaseEstimator):
# the emission distribution parameters to expose them publicly.

def __init__(self, n_components=1, startprob=None, transmat=None,
startprob_prior=None, transmat_prior=None, algorithm="viterbi"):
startprob_prior=None, transmat_prior=None,
algorithm="viterbi", random_state=None):
self.n_components = n_components

if startprob is None:
@@ -131,6 +135,7 @@ def __init__(self, n_components=1, startprob=None, transmat=None,
self._algorithm = algorithm
else:
self._algorithm = "viterbi"
self.random_state = random_state

def eval(self, obs):
"""Compute the log probability under the model and compute posteriors
@@ -327,12 +332,18 @@ def sample(self, n=1, random_state=None):
n : int
Number of samples to generate.
random_state : RandomState or an int seed (None by default)
A random number generator instance. If None is given, the
object's random_state is used
Returns
-------
(obs, hidden_states)
obs : array_like, length `n`
List of samples
hidden_states : array_like, length `n`
List of hidden states
"""
if random_state is None:
random_state = self.random_state
random_state = check_random_state(random_state)

startprob_pdf = self.startprob_
@@ -345,14 +356,14 @@
currstate = (startprob_cdf > rand).argmax()
hidden_states = [currstate]
obs = [self._generate_sample_from_state(
currstate, random_state=random_state)]
currstate, random_state=random_state)]

for _ in xrange(n - 1):
rand = random_state.rand()
currstate = (transmat_cdf[currstate] > rand).argmax()
hidden_states.append(currstate)
obs.append(self._generate_sample_from_state(
currstate, random_state=random_state))
currstate, random_state=random_state))

return np.array(obs), np.array(hidden_states, dtype=int)
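Two small idioms in the method above are worth isolating. `check_random_state` resolves the `None` / int / `RandomState` convention this changeset threads through every estimator, and `(cdf > rand).argmax()` is inverse-transform sampling for a discrete distribution: `argmax` returns the first index whose cumulative probability exceeds a uniform draw. A self-contained sketch of both, with illustrative values:

import numpy as np
from sklearn.utils import check_random_state

rng = check_random_state(0)      # int seed -> seeded RandomState instance

pdf = np.array([0.2, 0.5, 0.3])  # discrete distribution over 3 states
cdf = np.cumsum(pdf)             # [0.2, 0.7, 1.0]

rand = rng.rand()                # uniform draw in [0, 1)
state = (cdf > rand).argmax()    # first index with cumulative mass > rand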

@@ -601,14 +612,18 @@ class GaussianHMM(_BaseHMM):
(`n_components`, `n_features`) if 'diag',
(`n_components`, `n_features`, `n_features`) if 'full'
random_state : RandomState or an int seed (None by default)
A random number generator instance
Examples
--------
>>> from sklearn.hmm import GaussianHMM
>>> GaussianHMM(n_components=2)
... #doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
GaussianHMM(algorithm='viterbi', covariance_type='diag', covars_prior=0.01,
covars_weight=1, means_prior=None, means_weight=0, n_components=2,
startprob=None, startprob_prior=1.0, transmat=None, transmat_prior=1.0)
random_state=None, startprob=None, startprob_prior=1.0, transmat=None,
transmat_prior=1.0)
See Also
--------
@@ -618,11 +633,13 @@
def __init__(self, n_components=1, covariance_type='diag', startprob=None,
transmat=None, startprob_prior=None, transmat_prior=None,
algorithm="viterbi", means_prior=None, means_weight=0,
covars_prior=1e-2, covars_weight=1):
covars_prior=1e-2, covars_weight=1,
random_state=None):
_BaseHMM.__init__(self, n_components, startprob, transmat,
startprob_prior=startprob_prior,
transmat_prior=transmat_prior,
algorithm=algorithm)
algorithm=algorithm,
random_state=random_state)

self._covariance_type = covariance_type
if not covariance_type in ['spherical', 'tied', 'diag', 'full']:
@@ -817,21 +834,26 @@ class MultinomialHMM(_BaseHMM):
emissionprob : array, shape (`n_components`, `n_symbols`)
Probability of emitting a given symbol when in each state.
random_state : RandomState or an int seed (None by default)
A random number generator instance
Examples
--------
>>> from sklearn.hmm import MultinomialHMM
>>> MultinomialHMM(n_components=2)
... #doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
MultinomialHMM(algorithm='viterbi', n_components=2, startprob=None,
startprob_prior=1.0, transmat=None, transmat_prior=1.0)
MultinomialHMM(algorithm='viterbi', n_components=2, random_state=None,
startprob=None, startprob_prior=1.0, transmat=None,
transmat_prior=1.0)
See Also
--------
GaussianHMM : HMM with Gaussian emissions
"""

def __init__(self, n_components=1, startprob=None, transmat=None,
startprob_prior=None, transmat_prior=None, algorithm="viterbi"):
startprob_prior=None, transmat_prior=None,
algorithm="viterbi", random_state=None):
"""Create a hidden Markov model with multinomial emissions.
Parameters
@@ -842,7 +864,8 @@ def __init__(self, n_components=1, startprob=None, transmat=None,
_BaseHMM.__init__(self, n_components, startprob, transmat,
startprob_prior=startprob_prior,
transmat_prior=transmat_prior,
algorithm=algorithm)
algorithm=algorithm,
random_state=random_state)

def _get_emissionprob(self):
"""Emission probability distribution for each state."""
@@ -919,6 +942,9 @@ class GMMHMM(_BaseHMM):
gmms : array of GMM objects, length `n_components`
GMM emission distributions for each state.
random_state : RandomState or an int seed (None by default)
A random number generator instance
Examples
--------
>>> from sklearn.hmm import GMMHMM
@@ -928,8 +954,8 @@
gmms=[GMM(covariance_type=None, min_covar=0.001, n_components=10,
random_state=None, thresh=0.01), GMM(covariance_type=None,
min_covar=0.001, n_components=10, random_state=None, thresh=0.01)],
n_components=2, n_mix=10, startprob=None, startprob_prior=1.0,
transmat=None, transmat_prior=1.0)
n_components=2, n_mix=10, random_state=None, startprob=None,
startprob_prior=1.0, transmat=None, transmat_prior=1.0)
See Also
--------
@@ -938,7 +964,8 @@

def __init__(self, n_components=1, n_mix=1, startprob=None, transmat=None,
startprob_prior=None, transmat_prior=None, algorithm="viterbi",
gmms=None, covariance_type='diag', covars_prior=1e-2):
gmms=None, covariance_type='diag', covars_prior=1e-2,
random_state=None):
"""Create a hidden Markov model with GMM emissions.
Parameters
Expand All @@ -949,7 +976,8 @@ def __init__(self, n_components=1, n_mix=1, startprob=None, transmat=None,
_BaseHMM.__init__(self, n_components, startprob, transmat,
startprob_prior=startprob_prior,
transmat_prior=transmat_prior,
algorithm=algorithm)
algorithm=algorithm,
random_state=random_state)

# XXX: Hotfix for n_mix that is incompatible with the scikit's
# BaseEstimator API
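A hedged construction sketch for the mixture-emission case (parameter values are illustrative; only arguments visible in the signature above are used):

from sklearn.hmm import GMMHMM

# Two hidden states, each emitting from a 3-component diagonal-covariance GMM
model = GMMHMM(n_components=2, n_mix=3, covariance_type='diag',
               random_state=42)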
4 changes: 2 additions & 2 deletions sklearn/mixture/gmm.py
@@ -124,8 +124,8 @@ class GMM(BaseEstimator):
use. Must be one of 'spherical', 'tied', 'diag', 'full'.
Defaults to 'diag'.
rng : numpy.random object, optional
Must support the full numpy random number generator API.
random_state : RandomState or an int seed (None by default)
A random number generator instance
min_covar : float, optional
Floor on the diagonal of the covariance matrix to prevent
