diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index 41f904e1a4b9d..9119385011a03 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -472,6 +472,143 @@ by defining the adjusted Rand index as follows:
    `_
 
+Adjusted Mutual Information
+---------------------------
+
+Presentation and usage
+~~~~~~~~~~~~~~~~~~~~~~
+
+Given the knowledge of the ground truth class assignments ``labels_true``
+and our clustering algorithm assignments of the same samples
+``labels_pred``, the **Adjusted Mutual Information** is a function that
+measures the **agreement** of the two assignments, ignoring permutations
+and **with chance normalization**::
+
+  >>> from sklearn import metrics
+  >>> labels_true = [0, 0, 0, 1, 1, 1]
+  >>> labels_pred = [0, 0, 1, 1, 2, 2]
+
+  >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.24...
+
+One can permute 0 and 1 in the predicted labels and rename `2` to `3` and get
+the same score::
+
+  >>> labels_pred = [1, 1, 0, 0, 3, 3]
+  >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.24...
+
+Furthermore, :func:`adjusted_mutual_info_score` is **symmetric**: swapping the
+arguments does not change the score. It can thus be used as a **consensus
+measure**::
+
+  >>> metrics.adjusted_mutual_info_score(labels_pred, labels_true)  # doctest: +ELLIPSIS
+  0.24...
+
+Perfect labeling is scored 1.0::
+
+  >>> labels_pred = labels_true[:]
+  >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred)
+  1.0
+
+Bad (e.g. independent) labelings have scores close to zero::
+
+  >>> labels_true = [0, 1, 2, 0, 3, 4, 5, 1]
+  >>> labels_pred = [1, 1, 0, 0, 2, 2, 2, 2]
+  >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.0...
+
+
+Advantages
+~~~~~~~~~~
+
+- **Random (uniform) label assignments have an AMI score close to 0.0**
+  for any value of ``n_clusters`` and ``n_samples`` (which is not the
+  case for the raw Mutual Information or the V-measure for instance).
+
+- **Bounded range [0, 1]**: Values close to zero indicate two label
+  assignments that are largely independent, while values close to one
+  indicate significant agreement. Further, a value of exactly 0 indicates
+  **purely** independent label assignments and an AMI of exactly 1 indicates
+  that the two label assignments are equal (with or without permutation).
+
+- **No assumption is made on the cluster structure**: can be used
+  to compare clustering algorithms such as k-means, which assumes isotropic
+  blob shapes, with results of spectral clustering algorithms, which can
+  find clusters with "folded" shapes.
+
+
+Drawbacks
+~~~~~~~~~
+
+- Contrary to inertia, **AMI requires the knowledge of the ground truth
+  classes**, which is almost never available in practice or requires manual
+  assignment by human annotators (as in the supervised learning setting).
+
+  However AMI can also be useful in a purely unsupervised setting as a
+  building block for a Consensus Index that can be used for clustering
+  model selection.
+
+
+.. topic:: Examples:
+
+ * :ref:`example_cluster_plot_adjusted_for_chance_measures.py`: Analysis of
+   the impact of the dataset size on the value of clustering measures
+   for random assignments. This example also includes the Adjusted Rand
+   Index.
+
+
+Mathematical formulation
+~~~~~~~~~~~~~~~~~~~~~~~~
+Assume two label assignments (of the same data), :math:`U` with :math:`R`
+classes and :math:`V` with :math:`C` classes. The entropy of either is the
+amount of uncertainty for a label assignment, and can be calculated as:
+
+.. math:: H(U) = - \sum_{i=1}^{R}P(i)\log(P(i))
+
+Where :math:`P(i)` is the probability that an instance picked at random
+from :math:`U` falls into class :math:`R_i`. Likewise, for :math:`V`:
+
+.. math:: H(V) = - \sum_{j=1}^{C}P'(j)\log(P'(j))
+
+Where :math:`P'(j)` is the probability that an instance picked at random
+from :math:`V` falls into class :math:`C_j`.
+
+The (non-adjusted) mutual information between :math:`U` and :math:`V` is
+calculated by:
+
+.. math:: MI(U, V) = \sum_{i=1}^{R}\sum_{j=1}^{C}P(i, j)\log\left(\frac{P(i,j)}{P(i)P'(j)}\right)
+
+Where :math:`P(i, j)` is the probability that an instance picked at random
+falls into both class :math:`R_i` and class :math:`C_j`.
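+
+For instance, the raw mutual information of the labelings from the first
+example can be computed with the :func:`mutual_info_score` function. As the
+natural logarithm is used, the value here is
+:math:`\frac{2}{3}\log 2 \approx 0.46`::
+
+  >>> labels_true = [0, 0, 0, 1, 1, 1]
+  >>> labels_pred = [0, 0, 1, 1, 2, 2]
+  >>> metrics.mutual_info_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.46...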
+
+This value of the mutual information is not adjusted for chance and will tend
+to increase as the number of different labels (clusters) increases, regardless
+of the actual amount of "mutual information" between the label assignments.
+
+The expected value for the mutual information can be calculated using the
+following equation, from Vinh, Epps, and Bailey (2009). In this equation,
+:math:`a_i` is the number of instances in class :math:`R_i` of :math:`U` and
+:math:`b_j` is the number of instances in class :math:`C_j` of :math:`V`.
+
+
+.. math:: E\{MI(U,V)\}=\sum_{i=1}^R \sum_{j=1}^C \sum_{n_{ij}=(a_i+b_j-N)^+
+   }^{\min(a_i, b_j)} \frac{n_{ij}}{N}\log \left( \frac{ N \cdot n_{ij}}{a_i b_j}\right)
+   \frac{a_i!b_j!(N-a_i)!(N-b_j)!}{N!n_{ij}!(a_i-n_{ij})!(b_j-n_{ij})!
+   (N-a_i-b_j+n_{ij})!}
+
+Using the expected value, the adjusted mutual information can then be
+calculated using a similar form to that of the adjusted Rand index:
+
+.. math:: AMI = \frac{MI - E\{MI\}}{\max(H(U), H(V)) - E\{MI\}}
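+
+As an illustration of this formula, the AMI can be reassembled from its
+parts with the helpers that back :func:`adjusted_mutual_info_score` (note
+that ``entropy`` is a helper living in the
+``sklearn.metrics.cluster.supervised`` module)::
+
+  >>> import numpy as np
+  >>> from sklearn.metrics.cluster import contingency_matrix
+  >>> from sklearn.metrics.cluster import expected_mutual_information
+  >>> from sklearn.metrics.cluster.supervised import entropy
+  >>> labels_true = np.array([0, 0, 0, 1, 1, 1])
+  >>> labels_pred = np.array([0, 0, 1, 1, 2, 2])
+  >>> mi = metrics.mutual_info_score(labels_true, labels_pred)
+  >>> emi = expected_mutual_information(
+  ...     contingency_matrix(labels_true, labels_pred), len(labels_true))
+  >>> ami = (mi - emi) / (max(entropy(labels_true), entropy(labels_pred)) - emi)
+  >>> np.allclose(ami, metrics.adjusted_mutual_info_score(labels_true, labels_pred))
+  True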
+
+.. topic:: References
+
+ * Vinh, N. X., Epps, J., and Bailey, J. (2009). "Information theoretic
+   measures for clusterings comparison". Proceedings of the 26th Annual
+   International Conference on Machine Learning - ICML '09.
+   doi:10.1145/1553374.1553511. ISBN 9781605585161.
+
+ * `Wikipedia entry for the Adjusted Mutual Information
+   `_
+
 Homogeneity, completeness and V-measure
 ---------------------------------------
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 7b8ee64432a69..4dbef627a3e57 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -19,7 +19,7 @@ Changelog
    - Faster tests by `Fabian Pedregosa`_.
 
    - Silhouette Coefficient cluster analysis evaluation metric added as
-     ``sklearn.metrics.silhouette_score`` by Robert Layton.
+     ``sklearn.metrics.silhouette_score`` by `Robert Layton`_.
 
    - Fixed a bug in `KMeans` in the handling of the `n_init` parameter:
      the clustering algorithm used to be run `n_init` times but the last
@@ -28,6 +28,9 @@ Changelog
    - Minor refactoring in :ref:`sgd` module; consolidated dense and sparse
      predict methods.
 
+   - Adjusted Mutual Information metric added as
+     ``sklearn.metrics.adjusted_mutual_info_score`` by `Robert Layton`_.
+
 API changes summary
 -------------------
 
diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py
index bed0dcc8758a1..cec0c1e739083 100644
--- a/examples/cluster/plot_adjusted_for_chance_measures.py
+++ b/examples/cluster/plot_adjusted_for_chance_measures.py
@@ -27,11 +27,12 @@
 import numpy as np
 import pylab as pl
+from time import time
 
 from sklearn import metrics
 
 
 def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
-                             fixed_n_classes=None, n_runs=10, seed=42):
+                             fixed_n_classes=None, n_runs=5, seed=42):
     """Compute score for 2 random uniform cluster labelings.
 
     Both random labelings have the same number of clusters for each value
@@ -58,6 +59,8 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
 score_funcs = [
     metrics.adjusted_rand_score,
     metrics.v_measure_score,
+    metrics.adjusted_mutual_info_score,
+    metrics.mutual_info_score,
 ]
 
 # 2 independent random clusterings with equal cluster number
@@ -73,9 +76,12 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
     print "Computing %s for %d values of n_clusters and n_samples=%d" % (
         score_func.__name__, len(n_clusters_range), n_samples)
 
+    t0 = time()
     scores = uniform_labelings_scores(score_func, n_samples, n_clusters_range)
+    print "done in %0.3fs" % (time() - t0)
     plots.append(pl.errorbar(
-        n_clusters_range, scores.mean(axis=1), scores.std(axis=1)))
+        n_clusters_range, np.median(scores, axis=1), scores.std(axis=1)))
     names.append(score_func.__name__)
 
 pl.title("Clustering measures for 2 random uniform labelings\n"
@@ -86,6 +92,7 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
 pl.ylim(ymin=-0.05, ymax=1.05)
 pl.show()
 
+
 # Random labeling with varying n_clusters against ground class labels
 # with fixed number of clusters
 
@@ -101,8 +108,10 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
     print "Computing %s for %d values of n_clusters and n_samples=%d" % (
         score_func.__name__, len(n_clusters_range), n_samples)
 
+    t0 = time()
     scores = uniform_labelings_scores(score_func, n_samples, n_clusters_range,
                                       fixed_n_classes=n_classes)
+    print "done in %0.3fs" % (time() - t0)
     plots.append(pl.errorbar(
         n_clusters_range, scores.mean(axis=1), scores.std(axis=1)))
     names.append(score_func.__name__)
diff --git a/examples/cluster/plot_affinity_propagation.py b/examples/cluster/plot_affinity_propagation.py
index 73e8dd0271e3c..d3d5ff6eb33d0 100644
--- a/examples/cluster/plot_affinity_propagation.py
+++ b/examples/cluster/plot_affinity_propagation.py
@@ -40,6 +40,8 @@
 print "V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels)
 print "Adjusted Rand Index: %0.3f" % \
     metrics.adjusted_rand_score(labels_true, labels)
+print "Adjusted Mutual Information: %0.3f" % \
+    metrics.adjusted_mutual_info_score(labels_true, labels)
 D = (S / np.min(S))
 print ("Silhouette Coefficient: %0.3f" %
        metrics.silhouette_score(D, labels, metric='precomputed'))
diff --git a/examples/cluster/plot_color_quantization.py b/examples/cluster/plot_color_quantization.py
index 518376e86b816..d2af736e242bb 100644
--- a/examples/cluster/plot_color_quantization.py
+++ b/examples/cluster/plot_color_quantization.py
@@ -10,7 +10,7 @@
 In this example, pixels are represented in a 3D-space and K-means is used to
 find 64 color clusters. In the image processing literature, the codebook
-obtained from K-means (the cluster centers) is called the color palette. Using a
+obtained from K-means (the cluster centers) is called the color palette. Using a
 single byte, up to 256 colors can be addressed, whereas an RGB encoding requires
 3 bytes per pixel. The GIF file format, for example, uses such a palette.
 
@@ -61,7 +61,7 @@
 print "done in %0.3fs." % (time() - t0)
 
-codebook_random = shuffle(image_array, random_state=0)[:n_colors+1]
+codebook_random = shuffle(image_array, random_state=0)[:n_colors + 1]
 
 print "Predicting color indices on the full image (random)"
 t0 = time()
 dist = euclidean_distances(codebook_random, image_array, squared=True)
diff --git a/examples/cluster/plot_dbscan.py b/examples/cluster/plot_dbscan.py
index 809965d6f21fc..e6a686e064c47 100644
--- a/examples/cluster/plot_dbscan.py
+++ b/examples/cluster/plot_dbscan.py
@@ -41,6 +41,8 @@
 print "V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels)
 print "Adjusted Rand Index: %0.3f" % \
     metrics.adjusted_rand_score(labels_true, labels)
+print "Adjusted Mutual Information: %0.3f" % \
+    metrics.adjusted_mutual_info_score(labels_true, labels)
 print ("Silhouette Coefficient: %0.3f" %
        metrics.silhouette_score(D, labels, metric='precomputed'))
diff --git a/examples/cluster/plot_kmeans_digits.py b/examples/cluster/plot_kmeans_digits.py
index c1982ce268fc7..811bfdd712b8e 100644
--- a/examples/cluster/plot_kmeans_digits.py
+++ b/examples/cluster/plot_kmeans_digits.py
@@ -48,8 +48,8 @@
 print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_)
 print "Adjusted Rand Index: %0.3f" % \
     metrics.adjusted_rand_score(labels, km.labels_)
-#print ("Silhouette Coefficient: %0.3f" %
-#       metrics.silhouette_score(D, km.labels_, metric='precomputed'))
+print "Adjusted Mutual Information: %0.3f" % \
+    metrics.adjusted_mutual_info_score(labels, km.labels_)
 print ("Silhouette Coefficient: %0.3f" %
        metrics.silhouette_score(data, km.labels_, metric='euclidean',
                                 sample_size=sample_size))
@@ -65,8 +65,8 @@
 print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_)
 print "Adjusted Rand Index: %0.3f" % \
     metrics.adjusted_rand_score(labels, km.labels_)
-#print ("Silhouette Coefficient: %0.3f" %
-#       metrics.silhouette_score(D, km.labels_, metric='precomputed'))
+print "Adjusted Mutual Information: %0.3f" % \
+    metrics.adjusted_mutual_info_score(labels, km.labels_)
 print ("Silhouette Coefficient: %0.3f" %
        metrics.silhouette_score(data, km.labels_, metric='euclidean',
                                 sample_size=sample_size))
@@ -85,8 +85,8 @@
 print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_)
 print "Adjusted Rand Index: %0.3f" % \
     metrics.adjusted_rand_score(labels, km.labels_)
-#print ("Silhouette Coefficient: %0.3f" %
-#       metrics.silhouette_score(D, km.labels_, metric='precomputed'))
+print "Adjusted Mutual Information: %0.3f" % \
+    metrics.adjusted_mutual_info_score(labels, km.labels_)
 print ("Silhouette Coefficient: %0.3f" %
        metrics.silhouette_score(data, km.labels_, metric='euclidean',
                                 sample_size=sample_size))
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index f5aaeea2a9547..c8123b2343573 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -16,4 +16,6 @@
 from .cluster import completeness_score
 from .cluster import v_measure_score
 from .cluster import silhouette_score
+from .cluster import mutual_info_score
+from .cluster import adjusted_mutual_info_score
 from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels
diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py
index 5450f1112bbeb..81d092039a49a 100644
--- a/sklearn/metrics/cluster/__init__.py
+++ b/sklearn/metrics/cluster/__init__.py
@@ -5,7 +5,10 @@
 - supervised, which uses a ground truth class values for each sample.
 - unsupervised, which does not and measures the 'quality' of the model itself.
""" -from .supervised import (homogeneity_completeness_v_measure, - homogeneity_score, completeness_score, - v_measure_score, adjusted_rand_score) -from .unsupervised import silhouette_score, silhouette_samples +from supervised import (homogeneity_completeness_v_measure, + homogeneity_score, completeness_score, + v_measure_score, adjusted_rand_score, + adjusted_mutual_info_score, + expected_mutual_information, mutual_info_score, + contingency_matrix) +from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index ff57155a71944..f253a9a5f9085 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -9,6 +9,7 @@ from math import log from scipy.misc import comb +from scipy.special import gammaln import numpy as np @@ -38,6 +39,47 @@ def check_clusterings(labels_true, labels_pred): return labels_true, labels_pred +def contingency_matrix(labels_true, labels_pred, eps=None): + """Build a contengency matrix describing the relationship between labels. + + Parameters + ---------- + labels_true : int array, shape = [n_samples] + Ground truth class labels to be used as a reference + + labels_pred : array, shape = [n_samples] + Cluster labels to evaluate + + eps: None or float + If a float, that value is added to all values in the contingency + matrix. This helps to stop NaN propogation. + If None, nothing is adjusted. + + Returns + ------- + contingency: array, shape=[n_classes_true, n_classes_pred] + Matrix C such that C[i][j] is the number of samples in true class i and + in predicted class j. If eps is None, the dtype of this array will be + integer. If eps is given, the dtype will be float. + """ + classes = np.unique(labels_true) + clusters = np.unique(labels_pred) + # The cluster and class ids are not necessarily consecutive integers + # starting at 0 hence build a map + class_idx = dict((k, v) for v, k in enumerate(classes)) + cluster_idx = dict((k, v) for v, k in enumerate(clusters)) + # Build the contingency table + n_classes = classes.shape[0] + n_clusters = clusters.shape[0] + contingency = np.zeros((n_classes, n_clusters), dtype=np.int) + for c, k in zip(labels_true, labels_pred): + contingency[class_idx[c], cluster_idx[k]] += 1 + if eps is not None: + # Must be a float matrix to accept float eps + contingency = np.array(contingency, dtype='float') + eps + return contingency + + # clustering measures def adjusted_rand_score(labels_true, labels_pred): @@ -115,33 +157,20 @@ def adjusted_rand_score(labels_true, labels_pred): See also -------- - - ami_score: Adjusted Mutual Information (TODO: implement me!) + - adjusted_mutual_info_score: Adjusted Mutual Information """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) n_samples = labels_true.shape[0] - classes = np.unique(labels_true) clusters = np.unique(labels_pred) - # Special limit cases: no clustering since the data is not split. # This is a perfect match hence return 1.0. 
+    """
+    classes = np.unique(labels_true)
+    clusters = np.unique(labels_pred)
+    # The cluster and class ids are not necessarily consecutive integers
+    # starting at 0 hence build a map
+    class_idx = dict((k, v) for v, k in enumerate(classes))
+    cluster_idx = dict((k, v) for v, k in enumerate(clusters))
+    # Build the contingency table
+    n_classes = classes.shape[0]
+    n_clusters = clusters.shape[0]
+    contingency = np.zeros((n_classes, n_clusters), dtype=np.int)
+    for c, k in zip(labels_true, labels_pred):
+        contingency[class_idx[c], cluster_idx[k]] += 1
+    if eps is not None:
+        # Must be a float matrix to accept float eps
+        contingency = np.array(contingency, dtype='float') + eps
+    return contingency
+
+
 # clustering measures
 
 def adjusted_rand_score(labels_true, labels_pred):
@@ -115,33 +157,20 @@
 
     See also
     --------
-    - ami_score: Adjusted Mutual Information (TODO: implement me!)
+    - adjusted_mutual_info_score: Adjusted Mutual Information
     """
     labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
     n_samples = labels_true.shape[0]
-
     classes = np.unique(labels_true)
     clusters = np.unique(labels_pred)
-
     # Special limit cases: no clustering since the data is not split.
     # This is a perfect match hence return 1.0.
     if (classes.shape[0] == clusters.shape[0] == 1
         or classes.shape[0] == clusters.shape[0] == 0):
         return 1.0
-    # The cluster and class ids are not necessarily consecutive integers
-    # starting at 0 hence build a map
-    class_idx = dict((k, v) for v, k in enumerate(classes))
-    cluster_idx = dict((k, v) for v, k in enumerate(clusters))
-
-    # Build the contingency table
-    n_classes = classes.shape[0]
-    n_clusters = clusters.shape[0]
-    contingency = np.zeros((n_classes, n_clusters), dtype=np.int)
-
-    for c, k in zip(labels_true, labels_pred):
-        contingency[class_idx[c], cluster_idx[k]] += 1
+    contingency = contingency_matrix(labels_true, labels_pred)
 
     # Compute the ARI using the contingency data
     sum_comb_c = sum(comb2(n_c) for n_c in contingency.sum(axis=1))
@@ -458,3 +487,198 @@ def v_measure_score(labels_true, labels_pred):
     """
     return homogeneity_completeness_v_measure(labels_true, labels_pred)[2]
+
+
+def mutual_info_score(labels_true, labels_pred, contingency=None):
+    """Mutual Information between two clusterings
+
+    The Mutual Information is a measure of the similarity between two labels
+    of the same data. Where P(i) is the probability of a random sample
+    occurring in cluster U_i and P'(j) is the probability of a random sample
+    occurring in cluster V_j, the Mutual Information between clusterings U
+    and V is given as:
+
+        MI(U,V)=\sum_{i=1}^R \sum_{j=1}^C P(i,j)\log \frac{P(i,j)}{P(i)P'(j)}
+
+    This metric is independent of the absolute values of the labels:
+    a permutation of the class or cluster label values won't change the
+    score value in any way.
+
+    This metric is furthermore symmetric: switching ``labels_true`` with
+    ``labels_pred`` will return the same score value. This can be useful to
+    measure the agreement of two independent label assignment strategies
+    on the same dataset when the real ground truth is not known.
+
+    Parameters
+    ----------
+    labels_true : int array, shape = [n_samples]
+        A clustering of the data into disjoint subsets.
+
+    labels_pred : array, shape = [n_samples]
+        A clustering of the data into disjoint subsets.
+
+    contingency: None or array, shape = [n_classes_true, n_classes_pred]
+        A contingency matrix given by the contingency_matrix function.
+        If value is None, it will be computed, otherwise the given value is
+        used, with labels_true and labels_pred ignored.
+
+    Returns
+    -------
+    mi: float
+       Mutual information, a non-negative value
+
+    See also
+    --------
+    - adjusted_mutual_info_score: Adjusted Mutual Information
+    """
+    if contingency is None:
+        labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
+        contingency = contingency_matrix(labels_true, labels_pred)
+    contingency = np.array(contingency, dtype='float')
+    contingency /= np.sum(contingency)
+    pi = np.sum(contingency, axis=1)
+    pi /= np.sum(pi)
+    pj = np.sum(contingency, axis=0)
+    pj /= np.sum(pj)
+    outer = np.outer(pi, pj)
+    nnz = contingency != 0.0
+    mi = contingency[nnz] * np.log(contingency[nnz] / outer[nnz])
+    return mi.sum()
+
+
+def adjusted_mutual_info_score(labels_true, labels_pred):
+    """Adjusted Mutual Information between two clusterings
+
+    Adjusted Mutual Information (AMI) is an adjustment of the Mutual
+    Information (MI) score to account for chance. It corrects for the fact
+    that the MI is generally higher for two clusterings with a larger number
+    of clusters, regardless of whether there is actually more information
+    shared.
+
+    For two clusterings U and V, the AMI is given as:
+
+        AMI(U, V) = \frac{MI(U, V) - E(MI(U, V))}{\max(H(U), H(V)) - E(MI(U, V))}
+
+    This metric is independent of the absolute values of the labels:
+    a permutation of the class or cluster label values won't change the
+    score value in any way.
+
+    This metric is furthermore symmetric: switching ``labels_true`` with
+    ``labels_pred`` will return the same score value. This can be useful to
+    measure the agreement of two independent label assignment strategies
+    on the same dataset when the real ground truth is not known.
+
+    Be mindful that this function is an order of magnitude slower than other
+    metrics, such as the Adjusted Rand Index.
+
+    Parameters
+    ----------
+    labels_true : int array, shape = [n_samples]
+        A clustering of the data into disjoint subsets.
+
+    labels_pred : array, shape = [n_samples]
+        A clustering of the data into disjoint subsets.
+
+    Returns
+    -------
+    ami: float
+       Score between 0.0 and 1.0; 1.0 stands for a perfect match.
+
+    See also
+    --------
+    - adjusted_rand_score: Adjusted Rand Index
+    - mutual_info_score: Mutual Information (not adjusted for chance)
+
+    Examples
+    --------
+
+    Perfect labelings are scored 1.0::
+
+      >>> from sklearn.metrics.cluster import adjusted_mutual_info_score
+      >>> adjusted_mutual_info_score([0, 0, 1, 1], [0, 0, 1, 1])
+      1.0
+      >>> adjusted_mutual_info_score([0, 0, 1, 1], [1, 1, 0, 0])
+      1.0
+
+    If class members are completely split across different clusters, the
+    assignment is totally incomplete, hence the AMI is null::
+
+      >>> adjusted_mutual_info_score([0, 0, 0, 0], [0, 1, 2, 3])
+      0.0
+
+    """
+    labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
+    n_samples = labels_true.shape[0]
+    classes = np.unique(labels_true)
+    clusters = np.unique(labels_pred)
+    # Special limit cases: no clustering since the data is not split.
+    # This is a perfect match hence return 1.0.
+    if (classes.shape[0] == clusters.shape[0] == 1
+        or classes.shape[0] == clusters.shape[0] == 0):
+        return 1.0
+    contingency = contingency_matrix(labels_true, labels_pred)
+    contingency = np.array(contingency, dtype='float')
+    # Calculate the MI for the two clusterings
+    mi = mutual_info_score(labels_true, labels_pred,
+                           contingency=contingency)
+    # Calculate the expected value for the mutual information
+    emi = expected_mutual_information(contingency, n_samples)
+    # Calculate entropy for each labeling
+    h_true, h_pred = entropy(labels_true), entropy(labels_pred)
+    ami = (mi - emi) / (max(h_true, h_pred) - emi)
+    return ami
+
+
+def expected_mutual_information(contingency, n_samples):
+    """Calculate the expected mutual information for two labelings."""
+    R, C = contingency.shape
+    N = float(n_samples)
+    a = np.sum(contingency, axis=1, dtype='int')
+    b = np.sum(contingency, axis=0, dtype='int')
+    # There are three major terms to the EMI equation, which are multiplied
+    # together and then summed over varying nij values.
+    # While nijs[0] will never be used, having it simplifies the indexing.
+    nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float')
+    nijs[0] = 1  # Stops divide by zero warnings. As it's not used, no issue.
+    # term1 is nij / N
+    term1 = nijs / N
+    # term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b)
+    # term2 uses the outer product
+    log_ab_outer = np.log(np.outer(a, b))
+    # term2 uses N * nij
+    log_Nnij = np.log(N * nijs)
+    # term3 is large and involves many factorials. Calculate these in log
+    # space to stop overflows.
+    gln_a = gammaln(a + 1)
+    gln_b = gammaln(b + 1)
+    gln_Na = gammaln(N - a + 1)
+    gln_Nb = gammaln(N - b + 1)
+    gln_N = gammaln(N + 1)
+    gln_nij = gammaln(nijs + 1)
+    # start and end values for nij terms for each summation.
+    start = np.array([[v - N + w for w in b] for v in a], dtype='int')
+    start = np.maximum(start, 1)
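+    # The formula's lower bound for nij is max(0, a_i + b_j - N); starting
+    # the summation at 1 instead of 0 is safe because each summand is
+    # proportional to nij / N, so the nij = 0 term contributes nothing.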
+    end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1
+    # emi itself is a summation over the various values.
+    emi = 0
+    for i in range(R):
+        for j in range(C):
+            for nij in range(start[i][j], end[i][j]):
+                term2 = log_Nnij[nij] - log_ab_outer[i][j]
+                # Terms from the numerator are added in log space, terms
+                # from the denominator are subtracted.
+                gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j]
+                       - gln_N - gln_nij[nij] - gammaln(a[i] - nij + 1)
+                       - gammaln(b[j] - nij + 1)
+                       - gammaln(N - a[i] - b[j] + nij + 1))
+                term3 = np.exp(gln)
+                # Add the product of all terms.
+                emi += (term1[nij] * term2 * term3)
+    return emi
+
+
+def entropy(labels):
+    """Calculates the entropy for a labeling."""
+    pi = np.array([np.sum(labels == i) for i in np.unique(labels)],
+                  dtype='float')
+    pi = pi[pi > 0]
+    pi /= np.sum(pi)
+    return -np.sum(pi * np.log(pi))
diff --git a/sklearn/metrics/cluster/tests/test_cluster.py b/sklearn/metrics/cluster/tests/test_supervised.py
similarity index 76%
rename from sklearn/metrics/cluster/tests/test_cluster.py
rename to sklearn/metrics/cluster/tests/test_supervised.py
index 35a4982ca76f6..98399d9dc2288 100644
--- a/sklearn/metrics/cluster/tests/test_cluster.py
+++ b/sklearn/metrics/cluster/tests/test_supervised.py
@@ -5,6 +5,10 @@
 from sklearn.metrics.cluster import completeness_score
 from sklearn.metrics.cluster import v_measure_score
 from sklearn.metrics.cluster import homogeneity_completeness_v_measure
+from sklearn.metrics.cluster import adjusted_mutual_info_score
+from sklearn.metrics.cluster import mutual_info_score
+from sklearn.metrics.cluster import expected_mutual_information
+from sklearn.metrics.cluster import contingency_matrix
 
 from nose.tools import assert_almost_equal
 from nose.tools import assert_equal
@@ -16,6 +20,7 @@
     homogeneity_score,
     completeness_score,
     v_measure_score,
+    adjusted_mutual_info_score,
 ]
 
 
@@ -128,3 +133,28 @@ def test_adjustment_for_chance():
     max_abs_scores = np.abs(scores).max(axis=1)
     assert_array_almost_equal(max_abs_scores, [0.02, 0.03, 0.03, 0.02], 2)
+
+
+def test_adjusted_mutual_info_score():
+    """Compute the Adjusted Mutual Information and test against known values"""
+    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
+    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])
+    # Mutual information
+    mi = mutual_info_score(labels_a, labels_b)
+    assert_almost_equal(mi, 0.41022, 5)
+    # Expected mutual information
+    C = contingency_matrix(labels_a, labels_b)
+    n_samples = np.sum(C)
+    emi = expected_mutual_information(C, n_samples)
+    assert_almost_equal(emi, 0.15042, 5)
+    # Adjusted mutual information
+    ami = adjusted_mutual_info_score(labels_a, labels_b)
+    assert_almost_equal(ami, 0.27502, 5)
+    ami = adjusted_mutual_info_score([1, 1, 2, 2], [2, 2, 3, 3])
+    assert_equal(ami, 1.0)
+    # Test with a very large array
+    a110 = np.array([list(labels_a) * 110]).flatten()
+    b110 = np.array([list(labels_b) * 110]).flatten()
+    ami = adjusted_mutual_info_score(a110, b110)
+    # This is not accurate to more than 2 places
+    assert_almost_equal(ami, 0.37, 2)
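+    # MI, EMI and max(H(U), H(V)) are all symmetric in their arguments, so
+    # the AMI should be symmetric as well.
+    ami_ab = adjusted_mutual_info_score(labels_a, labels_b)
+    ami_ba = adjusted_mutual_info_score(labels_b, labels_a)
+    assert_almost_equal(ami_ab, ami_ba)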