From 2eb983342bfda3520aef3babfaddcfed7f0609f0 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 14 Oct 2011 14:34:43 +1100 Subject: [PATCH 01/30] Trying to fix NaN errors, but its not working. Pushing to work on it later. --- sklearn/metrics/cluster/__init__.py | 2 +- sklearn/metrics/cluster/supervised.py | 319 ++++++++++++++++++++++++-- 2 files changed, 305 insertions(+), 16 deletions(-) diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 63178fe3b10ba..21a0ac1eafe08 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -7,5 +7,5 @@ """ from supervised import (homogeneity_completeness_v_measure, homogeneity_score, completeness_score, - v_measure_score, adjusted_rand_score) + v_measure_score, adjusted_rand_score, ami_score) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 31fe009324fe2..966524df0b186 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -8,7 +8,7 @@ # License: BSD Style. from math import log -from scipy.misc import comb +from scipy.misc import comb, factorial import numpy as np @@ -38,6 +38,51 @@ def check_clusterings(labels_true, labels_pred): return labels_true, labels_pred +def contingency_matrix(labels_true, labels_pred, eps=None): + """ Build a contengency matrix describing the relationship between labels. + + Parameters + ---------- + labels_true : int array, shape = [n_samples] + Ground truth class labels to be used as a reference + + labels_pred : array, shape = [n_samples] + Cluster labels to evaluate + + eps: None or float + If a float, that value is added to all values in the contingency matrix. + This helps to stop NaN propogation. + If None, nothing is adjusted. + + Returns + ------- + contingency: array, shape=[n_classes_true, n_classes_pred] + Matrix C such that C[i][j] is the number of samples in true class i and + in predicted class j. + """ + n_samples = labels_true.shape[0] + + classes = np.unique(labels_true) + clusters = np.unique(labels_pred) + + # The cluster and class ids are not necessarily consecutive integers + # starting at 0 hence build a map + class_idx = dict((k, v) for v, k in enumerate(classes)) + cluster_idx = dict((k, v) for v, k in enumerate(clusters)) + + # Build the contingency table + n_classes = classes.shape[0] + n_clusters = clusters.shape[0] + contingency = np.zeros((n_classes, n_clusters), dtype=np.int) + + for c, k in zip(labels_true, labels_pred): + contingency[class_idx[c], cluster_idx[k]] += 1 + if eps is not None: + # Must be a float matrix to accept float eps + contingency = np.array(contingency, dtype='float') + eps + return contingency + + # clustering measures def adjusted_rand_score(labels_true, labels_pred): @@ -120,28 +165,15 @@ def adjusted_rand_score(labels_true, labels_pred): """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) n_samples = labels_true.shape[0] - classes = np.unique(labels_true) clusters = np.unique(labels_pred) - # Special limit cases: no clustering since the data is not split. # This is a perfect match hence return 1.0. 
if (classes.shape[0] == clusters.shape[0] == 1 or classes.shape[0] == clusters.shape[0] == 0): return 1.0 - # The cluster and class ids are not necessarily consecutive integers - # starting at 0 hence build a map - class_idx = dict((k, v) for v, k in enumerate(classes)) - cluster_idx = dict((k, v) for v, k in enumerate(clusters)) - - # Build the contingency table - n_classes = classes.shape[0] - n_clusters = clusters.shape[0] - contingency = np.zeros((n_classes, n_clusters), dtype=np.int) - - for c, k in zip(labels_true, labels_pred): - contingency[class_idx[c], cluster_idx[k]] += 1 + contingency = contingency_matrix(labels_true, labels_pred) # Compute the ARI using the contingency data sum_comb_c = sum(comb2(n_c) for n_c in contingency.sum(axis=1)) @@ -458,3 +490,260 @@ def v_measure_score(labels_true, labels_pred): """ return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] + +def v_measure_score(labels_true, labels_pred): + """V-Measure cluster labeling given a ground truth + + The V-Measure is the hormonic mean between homogeneity and completeness: + + v = 2 * (homogeneity * completeness) / (homogeneity + completeness) + + This metric is independent of the absolute values of the labels: + a permutation of the class or cluster label values won't change the + score value in any way. + + This metric is furthermore symmetric: switching `label_true` with + `label_pred` will return the same score value. This can be useful to + measure the agreement of two independent label assignments strategies + on the same dataset when the real ground truth is not known. + + Parameters + ---------- + labels_true : int array, shape = [n_samples] + ground truth class labels to be used as a reference + + labels_pred : array, shape = [n_samples] + cluster labels to evaluate + + Returns + ------- + completeness: float + score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling + + References + ---------- + V-Measure: A conditional entropy-based external cluster evaluation measure + Andrew Rosenberg and Julia Hirschberg, 2007 + http://acl.ldc.upenn.edu/D/D07/D07-1043.pdf + + See also + -------- + - homogeneity_score + - completeness_score + + Examples + -------- + + Perfect labelings are both homogeneous and complete, hence have score 1.0:: + + >>> from sklearn.metrics.cluster import v_measure_score + >>> v_measure_score([0, 0, 1, 1], [0, 0, 1, 1]) + 1.0 + >>> v_measure_score([0, 0, 1, 1], [1, 1, 0, 0]) + 1.0 + + Labelings that assign all classes members to the same clusters + are complete be not homogeneous, hence penalized:: + + >>> v_measure_score([0, 0, 1, 2], [0, 0, 1, 1]) # doctest: +ELLIPSIS + 0.8... + >>> v_measure_score([0, 1, 2, 3], [0, 0, 1, 1]) # doctest: +ELLIPSIS + 0.66... + + Labelings that have pure clusters with members coming from the same + classes are homogeneous but un-necessary splits harms completeness + and thus penalize V-measure as well:: + + >>> v_measure_score([0, 0, 1, 1], [0, 0, 1, 2]) # doctest: +ELLIPSIS + 0.8... + >>> v_measure_score([0, 0, 1, 1], [0, 1, 2, 3]) # doctest: +ELLIPSIS + 0.66... 
+ + If classes members are completly splitted accross different clusters, + the assignment is totally in-complete, hence the v-measure is null:: + + >>> v_measure_score([0, 0, 0, 0], [0, 1, 2, 3]) + 0.0 + + Clusters that include samples from totally different classes totally + destroy the homogeneity of the labeling, hence:: + + >>> v_measure_score([0, 0, 1, 1], [0, 0, 0, 0]) + 0.0 + + """ + return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] + + +def mutual_information(labels_true, labels_pred, contingency=None): + """Adjusted Mutual Information between two clusterings + + The Mutual Information is a measure of the similarity between two labels + of the same data. Where P(i) is the probability of a random sample occuring + in cluster U_i and P'(j) is the probability of a random sample occuring in + cluster V_j, the Mutual information between clusterings U and V is given + as: + + MI(U,V)=\sum_{i=1}^R \sum_{j=1}^C P(i,j)\log \frac{P(i,j)}{P(i)P'(j)} + + This metric is independent of the absolute values of the labels: + a permutation of the class or cluster label values won't change the + score value in any way. + + This metric is furthermore symmetric: switching `label_true` with + `label_pred` will return the same score value. This can be useful to + measure the agreement of two independent label assignments strategies + on the same dataset when the real ground truth is not known. + + Parameters + ---------- + labels_true : int array, shape = [n_samples] + A clustering of the data into disjoint subsets. + + labels_pred : array, shape = [n_samples] + A clustering of the data into disjoint subsets. + + contingency: None or array, shape = [n_classes_true, n_classes_pred] + A contingency matrix given by the contingency_matrix function. + If value is None, it will be computed, otherwise the given value is + used, with labels_true and labels_pred ignored. + + Returns + ------- + mi: float + Mutual information, a non-negative value + """ + if contingency is None: + labels_true, labels_pred = check_clusterings(labels_true, labels_pred) + contingency = contingency_matrix(labels_true, labels_pred) + # Calculate P(i) for all i and P'(j) for all j + pi = np.sum(contingency, axis=1) + pi /= float(np.sum(pi)) + pj = np.sum(contingency, axis=0) + pj /= float(np.sum(pj)) + # Compute log for all values + log_pij = np.log(contingency) + # Product of pi and pj for denominator + pi_pj = np.outer(pi, pj) + # Remembering that log(x/y) = log(x) - log(y) + mi = np.sum(contingency * (log_pij - pi_pj)) + return mi + + +def ami_score(labels_true, labels_pred): + """Adjusted Mutual Information between two clusterings + + Adjusted Mutual Information (AMI) is an adjustement of the Mutual + Information (MI) score to account for chance. It accounts for the fact that + the MI is generally higher for two clusterings with a larger number of + clusters, regardless of whether there is actually more information shared. + For two clusterings U and V, the AMI is given as: + + AMI(U, V) = \frac{MI(U, V) - E(MI(U, V))}{max(H(U), H(V)) - E(MI(U, V))} + + This metric is independent of the absolute values of the labels: + a permutation of the class or cluster label values won't change the + score value in any way. + + This metric is furthermore symmetric: switching `label_true` with + `label_pred` will return the same score value. This can be useful to + measure the agreement of two independent label assignments strategies + on the same dataset when the real ground truth is not known. 
+ + Parameters + ---------- + labels_true : int array, shape = [n_samples] + A clustering of the data into disjoint subsets. + + labels_pred : array, shape = [n_samples] + A clustering of the data into disjoint subsets. + + Returns + ------- + ami: float + score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling + + + Examples + -------- + + Perfect labelings are both homogeneous and complete, hence have score 1.0:: + + >>> from sklearn.metrics.cluster import ami_score + >>> ami_score([0, 0, 1, 1], [0, 0, 1, 1]) + 1.0 + >>> v_measure_score([0, 0, 1, 1], [1, 1, 0, 0]) + 1.0 + + + If classes members are completly splitted accross different clusters, + the assignment is totally in-complete, hence the AMI is null:: + + >>> v_measure_score([0, 0, 0, 0], [0, 1, 2, 3]) + 0.0 + + """ + labels_true, labels_pred = check_clusterings(labels_true, labels_pred) + n_samples = labels_true.shape[0] + classes = np.unique(labels_true) + clusters = np.unique(labels_pred) + # Special limit cases: no clustering since the data is not split. + # This is a perfect match hence return 1.0. + if (classes.shape[0] == clusters.shape[0] == 1 + or classes.shape[0] == clusters.shape[0] == 0): + return 1.0 + eps = np.finfo(float).eps + eps = 1e-15 + contingency = contingency_matrix(labels_true, labels_pred, + eps=eps) + # Calculate the MI for the two clusterings + mi = mutual_information(labels_true, labels_pred, contingency=contingency) + assert not np.isnan(mi), "mutual information is nan. %r\n%r\n%r" % (labels_true, labels_pred, contingency) + # Calcualte the expected value for the mutual information + emi = _expected_mutual_information(contingency, n_samples) + assert not np.isnan(emi), "emi is nan" + # Calculate entropy for each labelling + h_true, h_pred = entropy(labels_true), entropy(labels_pred) + assert not np.isnan(h_true), "h_true is nan" + assert not np.isnan(h_pred), "h_pred is nan" + ami = (mi - emi) / (max(h_true, h_pred) - emi) + assert not np.isnan(ami), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (mi, emi, h_true, h_pred, emi) + return ami + +def _expected_mutual_information(contingency, n_samples): + """ Calculate the expected mutual information for two labellings. """ + n_samples = float(n_samples) + M = np.zeros(contingency.shape, dtype='float') + R, C = contingency.shape + a = np.sum(contingency, axis=1) + b = np.sum(contingency, axis=0) + fact_N = factorial(n_samples) + fact_a = factorial(a) + fact_b = factorial(b) + fact_Na = factorial(n_samples - a) + fact_Nb = factorial(n_samples - b) + for i in range(R): + for j in range(C): + start = int(max(0, a[i] + b[j] - n_samples)) + end = int(min(a[i], b[j])) + if end == 0 or contingency[i][j] == 0: + continue + n1 = contingency[i][j] / n_samples + n2 = np.log(n_samples * contingency[i][j] / (a[i] * b[i])) + n3 = (fact_a[i] * fact_b[j] * fact_Na[i] * fact_Nb[j]) + numerator = n1 * n2 * n3 + assert not np.isnan(numerator), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (n1, n2, n3, n_samples, contingency[i][j], a[i] * b[i]) + for nij in range(start, end): + d1 = float(fact_N * factorial(nij) * factorial(a[i] - nij) * + factorial(b[j] - nij) * + factorial(n_samples - a[i] - b[j] + nij)) + assert not np.isnan(numerator / d1) or not np.isinf(numerator / d1), "%.4f, %.2f" % (d1, nij) + M[i][j] += numerator / d1 + assert not np.isnan(np.sum(M)), M + return np.sum(M) + + +def entropy(labels): + """ Calculates the entropy for a labelling. 
""" + pi = np.array([np.sum(labels == i) for i in np.unique(labels)]) + return -np.sum(pi * np.log(pi)) From c3d3906a918d661c54918708135e2448c2fcbc87 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sat, 15 Oct 2011 22:42:52 +1100 Subject: [PATCH 02/30] Mutual information now works (tested!) --- sklearn/metrics/__init__.py | 1 + sklearn/metrics/cluster/__init__.py | 3 +- sklearn/metrics/cluster/supervised.py | 52 +++++++++++++-------------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index f5aaeea2a9547..3fbc895f8968c 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -16,4 +16,5 @@ from .cluster import completeness_score from .cluster import v_measure_score from .cluster import silhouette_score +from .cluster import mutual_information from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 21a0ac1eafe08..e168eb889fa1e 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -7,5 +7,6 @@ """ from supervised import (homogeneity_completeness_v_measure, homogeneity_score, completeness_score, - v_measure_score, adjusted_rand_score, ami_score) + v_measure_score, adjusted_rand_score, ami_score, + mutual_information) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 966524df0b186..2ef20c635eb49 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -491,6 +491,7 @@ def v_measure_score(labels_true, labels_pred): """ return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] + def v_measure_score(labels_true, labels_pred): """V-Measure cluster labeling given a ground truth @@ -616,18 +617,14 @@ def mutual_information(labels_true, labels_pred, contingency=None): if contingency is None: labels_true, labels_pred = check_clusterings(labels_true, labels_pred) contingency = contingency_matrix(labels_true, labels_pred) - # Calculate P(i) for all i and P'(j) for all j + contingency = np.array(contingency, dtype='float') + contingency /= np.sum(contingency) pi = np.sum(contingency, axis=1) - pi /= float(np.sum(pi)) + pi /= np.sum(pi) pj = np.sum(contingency, axis=0) - pj /= float(np.sum(pj)) - # Compute log for all values - log_pij = np.log(contingency) - # Product of pi and pj for denominator - pi_pj = np.outer(pi, pj) - # Remembering that log(x/y) = log(x) - log(y) - mi = np.sum(contingency * (log_pij - pi_pj)) - return mi + pj /= np.sum(pj) + mi = contingency * np.log2(contingency / np.outer(pi, pj)) + return np.sum(mi[np.isfinite(mi)]) def ami_score(labels_true, labels_pred): @@ -672,14 +669,14 @@ def ami_score(labels_true, labels_pred): >>> from sklearn.metrics.cluster import ami_score >>> ami_score([0, 0, 1, 1], [0, 0, 1, 1]) 1.0 - >>> v_measure_score([0, 0, 1, 1], [1, 1, 0, 0]) + >>> ami_score([0, 0, 1, 1], [1, 1, 0, 0]) 1.0 If classes members are completly splitted accross different clusters, the assignment is totally in-complete, hence the AMI is null:: - >>> v_measure_score([0, 0, 0, 0], [0, 1, 2, 3]) + >>> ami_score([0, 0, 0, 0], [0, 1, 2, 3]) 0.0 """ @@ -692,28 +689,26 @@ def ami_score(labels_true, labels_pred): if (classes.shape[0] == clusters.shape[0] == 1 or classes.shape[0] == clusters.shape[0] == 0): return 1.0 - eps = np.finfo(float).eps - eps = 1e-15 - contingency = 
contingency_matrix(labels_true, labels_pred, - eps=eps) + contingency = contingency_matrix(labels_true, labels_pred) + contingency = np.array(contingency, dtype='float') # Calculate the MI for the two clusterings mi = mutual_information(labels_true, labels_pred, contingency=contingency) - assert not np.isnan(mi), "mutual information is nan. %r\n%r\n%r" % (labels_true, labels_pred, contingency) # Calcualte the expected value for the mutual information emi = _expected_mutual_information(contingency, n_samples) - assert not np.isnan(emi), "emi is nan" + assert np.isfinite(emi), "emi is nan" # Calculate entropy for each labelling h_true, h_pred = entropy(labels_true), entropy(labels_pred) - assert not np.isnan(h_true), "h_true is nan" - assert not np.isnan(h_pred), "h_pred is nan" + assert np.isfinite(h_true), "h_true is nan" + assert np.isfinite(h_pred), "h_pred is nan" ami = (mi - emi) / (max(h_true, h_pred) - emi) - assert not np.isnan(ami), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (mi, emi, h_true, h_pred, emi) + assert np.isfinite(ami), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (mi, emi, h_true, h_pred, emi) return ami + def _expected_mutual_information(contingency, n_samples): """ Calculate the expected mutual information for two labellings. """ n_samples = float(n_samples) - M = np.zeros(contingency.shape, dtype='float') + M = 0. R, C = contingency.shape a = np.sum(contingency, axis=1) b = np.sum(contingency, axis=0) @@ -732,18 +727,19 @@ def _expected_mutual_information(contingency, n_samples): n2 = np.log(n_samples * contingency[i][j] / (a[i] * b[i])) n3 = (fact_a[i] * fact_b[j] * fact_Na[i] * fact_Nb[j]) numerator = n1 * n2 * n3 - assert not np.isnan(numerator), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (n1, n2, n3, n_samples, contingency[i][j], a[i] * b[i]) + assert np.isfinite(numerator), "%r,%r,%r" % (n1, n2, n3) for nij in range(start, end): d1 = float(fact_N * factorial(nij) * factorial(a[i] - nij) * factorial(b[j] - nij) * factorial(n_samples - a[i] - b[j] + nij)) - assert not np.isnan(numerator / d1) or not np.isinf(numerator / d1), "%.4f, %.2f" % (d1, nij) - M[i][j] += numerator / d1 - assert not np.isnan(np.sum(M)), M - return np.sum(M) + M += numerator / d1 + return M def entropy(labels): """ Calculates the entropy for a labelling. """ - pi = np.array([np.sum(labels == i) for i in np.unique(labels)]) + pi = np.array([np.sum(labels == i) for i in np.unique(labels)], + dtype='float') + pi = pi[pi > 0] + pi /= np.sum(pi) return -np.sum(pi * np.log(pi)) From 9f05c114f8c9a5ec9f7973ffd54996795818b3f0 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Wed, 19 Oct 2011 21:25:09 +1100 Subject: [PATCH 03/30] AMI now works, and has been tested against the matlab code (test based on this to come!) --- sklearn/metrics/cluster/supervised.py | 65 +++++++++++++-------------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 2ef20c635eb49..403a6ad4d354d 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -50,31 +50,27 @@ def contingency_matrix(labels_true, labels_pred, eps=None): Cluster labels to evaluate eps: None or float - If a float, that value is added to all values in the contingency matrix. - This helps to stop NaN propogation. + If a float, that value is added to all values in the contingency + matrix. This helps to stop NaN propogation. If None, nothing is adjusted. 
Returns ------- contingency: array, shape=[n_classes_true, n_classes_pred] Matrix C such that C[i][j] is the number of samples in true class i and - in predicted class j. + in predicted class j. If eps is None, the dtype of this array will be + integer. If eps is given, the dtype will be float. """ - n_samples = labels_true.shape[0] - classes = np.unique(labels_true) clusters = np.unique(labels_pred) - # The cluster and class ids are not necessarily consecutive integers # starting at 0 hence build a map class_idx = dict((k, v) for v, k in enumerate(classes)) cluster_idx = dict((k, v) for v, k in enumerate(clusters)) - # Build the contingency table n_classes = classes.shape[0] n_clusters = clusters.shape[0] contingency = np.zeros((n_classes, n_clusters), dtype=np.int) - for c, k in zip(labels_true, labels_pred): contingency[class_idx[c], cluster_idx[k]] += 1 if eps is not None: @@ -659,7 +655,6 @@ def ami_score(labels_true, labels_pred): ------- ami: float score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling - Examples -------- @@ -695,45 +690,45 @@ def ami_score(labels_true, labels_pred): mi = mutual_information(labels_true, labels_pred, contingency=contingency) # Calcualte the expected value for the mutual information emi = _expected_mutual_information(contingency, n_samples) - assert np.isfinite(emi), "emi is nan" # Calculate entropy for each labelling h_true, h_pred = entropy(labels_true), entropy(labels_pred) - assert np.isfinite(h_true), "h_true is nan" - assert np.isfinite(h_pred), "h_pred is nan" ami = (mi - emi) / (max(h_true, h_pred) - emi) - assert np.isfinite(ami), "%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (mi, emi, h_true, h_pred, emi) return ami def _expected_mutual_information(contingency, n_samples): """ Calculate the expected mutual information for two labellings. """ - n_samples = float(n_samples) - M = 0. 
R, C = contingency.shape + N = n_samples a = np.sum(contingency, axis=1) b = np.sum(contingency, axis=0) - fact_N = factorial(n_samples) - fact_a = factorial(a) - fact_b = factorial(b) - fact_Na = factorial(n_samples - a) - fact_Nb = factorial(n_samples - b) + factA = factorial(a) + factB = factorial(b) + factNA = factorial(N - a) + factNB = factorial(N - b) + factN = factorial(N) + emi = 0 for i in range(R): for j in range(C): - start = int(max(0, a[i] + b[j] - n_samples)) - end = int(min(a[i], b[j])) - if end == 0 or contingency[i][j] == 0: - continue - n1 = contingency[i][j] / n_samples - n2 = np.log(n_samples * contingency[i][j] / (a[i] * b[i])) - n3 = (fact_a[i] * fact_b[j] * fact_Na[i] * fact_Nb[j]) - numerator = n1 * n2 * n3 - assert np.isfinite(numerator), "%r,%r,%r" % (n1, n2, n3) + # numerator of the third term + num3 = factA[i] * factB[j] * factNA[i] * factNB[j] + assert np.isfinite(num3) + start = int(max(a[i] + b[j] - N, 1)) + end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - d1 = float(fact_N * factorial(nij) * factorial(a[i] - nij) * - factorial(b[j] - nij) * - factorial(n_samples - a[i] - b[j] + nij)) - M += numerator / d1 - return M + nij = float(nij) + # term 1: nij / N + term1 = nij / N + assert np.isfinite(term1) + # term 2: log(N.nij / ai.bj) + term2 = np.log((N * nij) / (a[i] * b[j])) + assert np.isfinite(term2) + # denominator of term 3 + den3 = float(factN * factorial(nij) * factorial(a[i] - nij) + * factorial(b[j] - nij) + * factorial(N - a[i] - b[j] + nij)) + emi += term1 * term2 * (num3 / den3) + return emi def entropy(labels): @@ -742,4 +737,4 @@ def entropy(labels): dtype='float') pi = pi[pi > 0] pi /= np.sum(pi) - return -np.sum(pi * np.log(pi)) + return -np.sum(pi * np.log2(pi)) From e473f70be46e583f3bb6f50d72a7f740c7b0ea33 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Wed, 19 Oct 2011 21:46:07 +1100 Subject: [PATCH 04/30] Remove phantom double v-measure !? --- sklearn/metrics/cluster/supervised.py | 84 --------------------------- 1 file changed, 84 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 403a6ad4d354d..17a23b871ba47 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -488,90 +488,6 @@ def v_measure_score(labels_true, labels_pred): return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] -def v_measure_score(labels_true, labels_pred): - """V-Measure cluster labeling given a ground truth - - The V-Measure is the hormonic mean between homogeneity and completeness: - - v = 2 * (homogeneity * completeness) / (homogeneity + completeness) - - This metric is independent of the absolute values of the labels: - a permutation of the class or cluster label values won't change the - score value in any way. - - This metric is furthermore symmetric: switching `label_true` with - `label_pred` will return the same score value. This can be useful to - measure the agreement of two independent label assignments strategies - on the same dataset when the real ground truth is not known. - - Parameters - ---------- - labels_true : int array, shape = [n_samples] - ground truth class labels to be used as a reference - - labels_pred : array, shape = [n_samples] - cluster labels to evaluate - - Returns - ------- - completeness: float - score between 0.0 and 1.0. 
1.0 stands for perfectly complete labeling - - References - ---------- - V-Measure: A conditional entropy-based external cluster evaluation measure - Andrew Rosenberg and Julia Hirschberg, 2007 - http://acl.ldc.upenn.edu/D/D07/D07-1043.pdf - - See also - -------- - - homogeneity_score - - completeness_score - - Examples - -------- - - Perfect labelings are both homogeneous and complete, hence have score 1.0:: - - >>> from sklearn.metrics.cluster import v_measure_score - >>> v_measure_score([0, 0, 1, 1], [0, 0, 1, 1]) - 1.0 - >>> v_measure_score([0, 0, 1, 1], [1, 1, 0, 0]) - 1.0 - - Labelings that assign all classes members to the same clusters - are complete be not homogeneous, hence penalized:: - - >>> v_measure_score([0, 0, 1, 2], [0, 0, 1, 1]) # doctest: +ELLIPSIS - 0.8... - >>> v_measure_score([0, 1, 2, 3], [0, 0, 1, 1]) # doctest: +ELLIPSIS - 0.66... - - Labelings that have pure clusters with members coming from the same - classes are homogeneous but un-necessary splits harms completeness - and thus penalize V-measure as well:: - - >>> v_measure_score([0, 0, 1, 1], [0, 0, 1, 2]) # doctest: +ELLIPSIS - 0.8... - >>> v_measure_score([0, 0, 1, 1], [0, 1, 2, 3]) # doctest: +ELLIPSIS - 0.66... - - If classes members are completly splitted accross different clusters, - the assignment is totally in-complete, hence the v-measure is null:: - - >>> v_measure_score([0, 0, 0, 0], [0, 1, 2, 3]) - 0.0 - - Clusters that include samples from totally different classes totally - destroy the homogeneity of the labeling, hence:: - - >>> v_measure_score([0, 0, 1, 1], [0, 0, 0, 0]) - 0.0 - - """ - return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] - - def mutual_information(labels_true, labels_pred, contingency=None): """Adjusted Mutual Information between two clusterings From 2d8677f953bcad53a51c5ee72b61ff4a71897aad Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Wed, 19 Oct 2011 22:33:13 +1100 Subject: [PATCH 05/30] Added tests. There are two errors, but I'm going to bed. I'll fix them in the morning. 
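Note: the mutual information value these tests pin down can be recomputed by hand
from the contingency table. Below is a minimal standalone NumPy sketch
(illustrative only, not part of the patch); it assumes base-2 logarithms, which
is what the implementation uses at this point in the series, hence the expected
value of ~0.59182 bits here rather than the ~0.41022 obtained once a later patch
switches to natural logarithms:

    import numpy as np

    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])

    # Contingency table: C[i, j] counts samples with true class i and
    # predicted cluster j.
    classes, clusters = np.unique(labels_a), np.unique(labels_b)
    C = np.array([[np.sum((labels_a == c) & (labels_b == k))
                   for k in clusters] for c in classes], dtype=float)

    # MI(U, V) = sum_ij P(i, j) * log2(P(i, j) / (P(i) * P'(j))),
    # summed over the non-empty cells only, to avoid log(0).
    P = C / C.sum()
    pi, pj = P.sum(axis=1), P.sum(axis=0)
    nz = P > 0
    mi = np.sum(P[nz] * np.log2(P[nz] / np.outer(pi, pj)[nz]))
    print(round(mi, 5))  # ~0.59182 bits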
--- sklearn/metrics/cluster/__init__.py | 5 ++-- sklearn/metrics/cluster/supervised.py | 4 ++-- .../{test_cluster.py => test_supervised.py} | 24 +++++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) rename sklearn/metrics/cluster/tests/{test_cluster.py => test_supervised.py} (81%) diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index e168eb889fa1e..0967583ab6fab 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -7,6 +7,7 @@ """ from supervised import (homogeneity_completeness_v_measure, homogeneity_score, completeness_score, - v_measure_score, adjusted_rand_score, ami_score, - mutual_information) + v_measure_score, adjusted_rand_score, + ami_score, expected_mutual_information, + mutual_information, contingency_matrix) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 17a23b871ba47..eb18a0a63ba7b 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -605,14 +605,14 @@ def ami_score(labels_true, labels_pred): # Calculate the MI for the two clusterings mi = mutual_information(labels_true, labels_pred, contingency=contingency) # Calcualte the expected value for the mutual information - emi = _expected_mutual_information(contingency, n_samples) + emi = expected_mutual_information(contingency, n_samples) # Calculate entropy for each labelling h_true, h_pred = entropy(labels_true), entropy(labels_pred) ami = (mi - emi) / (max(h_true, h_pred) - emi) return ami -def _expected_mutual_information(contingency, n_samples): +def expected_mutual_information(contingency, n_samples): """ Calculate the expected mutual information for two labellings. 
""" R, C = contingency.shape N = n_samples diff --git a/sklearn/metrics/cluster/tests/test_cluster.py b/sklearn/metrics/cluster/tests/test_supervised.py similarity index 81% rename from sklearn/metrics/cluster/tests/test_cluster.py rename to sklearn/metrics/cluster/tests/test_supervised.py index 35a4982ca76f6..1c0c197b016c7 100644 --- a/sklearn/metrics/cluster/tests/test_cluster.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -5,6 +5,10 @@ from sklearn.metrics.cluster import completeness_score from sklearn.metrics.cluster import v_measure_score from sklearn.metrics.cluster import homogeneity_completeness_v_measure +from sklearn.metrics.cluster import ami_score +from sklearn.metrics.cluster import mutual_information +from sklearn.metrics.cluster import expected_mutual_information +from sklearn.metrics.cluster import contingency_matrix from nose.tools import assert_almost_equal from nose.tools import assert_equal @@ -16,6 +20,7 @@ homogeneity_score, completeness_score, v_measure_score, + ami_score, ] @@ -128,3 +133,22 @@ def test_adjustment_for_chance(): max_abs_scores = np.abs(scores).max(axis=1) assert_array_almost_equal(max_abs_scores, [0.02, 0.03, 0.03, 0.02], 2) + + +def test_ami_score(): + """Compute the Adjusted Mutual Information and test against known values""" + labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) + # Mutual information + mi = mutual_information(labels_a, labels_b) + assert_almost_equal(mi, 0.59182, 5) + # Expected mutual information + C = contingency_matrix(labels_a, labels_b) + n_samples = np.sum(C) + emi = expected_mutual_information(C, n_samples) + assert_almost_equal(emi, 0.15042, 5) + # Adjusted mutual information + ami = ami_score(labels_a, labels_b) + assert_almost_equal(ami, 0.27502, 5) + ami = ami_score([1, 1, 2, 2], [2, 2, 3, 3]) + assert_equal(ami, 1.0) From 193ac024a4bd1125ea132bb4c0c2c0a6cbf2680c Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 10:52:42 +1100 Subject: [PATCH 06/30] - AMI in the cluster examples - See Also sections updated - mutual_information -> mutual_information_score (and updating subsequent imports) --- .../plot_adjusted_for_chance_measures.py | 1 + examples/cluster/plot_affinity_propagation.py | 2 ++ examples/cluster/plot_dbscan.py | 2 ++ examples/cluster/plot_kmeans_digits.py | 4 ++-- sklearn/metrics/__init__.py | 2 +- sklearn/metrics/cluster/__init__.py | 2 +- sklearn/metrics/cluster/supervised.py | 24 +++++++++++++------ .../metrics/cluster/tests/test_supervised.py | 4 ++-- 8 files changed, 28 insertions(+), 13 deletions(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index bed0dcc8758a1..8f05747d0fd9c 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -58,6 +58,7 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, score_funcs = [ metrics.adjusted_rand_score, metrics.v_measure_score, + metrics.ami_score ] # 2 independent random clusterings with equal cluster number diff --git a/examples/cluster/plot_affinity_propagation.py b/examples/cluster/plot_affinity_propagation.py index 73e8dd0271e3c..9aedba4e3d970 100644 --- a/examples/cluster/plot_affinity_propagation.py +++ b/examples/cluster/plot_affinity_propagation.py @@ -40,6 +40,8 @@ print "V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels) print "Adjusted Rand 
Index: %0.3f" % \ metrics.adjusted_rand_score(labels_true, labels) +print "Adjusted Mutual Information: %0.3f" % \ + metrics.ami_score(labels_true, labels) D = (S / np.min(S)) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(D, labels, metric='precomputed')) diff --git a/examples/cluster/plot_dbscan.py b/examples/cluster/plot_dbscan.py index 809965d6f21fc..8b12f0ccd6f11 100644 --- a/examples/cluster/plot_dbscan.py +++ b/examples/cluster/plot_dbscan.py @@ -41,6 +41,8 @@ print "V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels) print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels_true, labels) +print "Adjusted Mutual Information: %0.3f" % \ + metrics.ami_score(labels_true, labels) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(D, labels, metric='precomputed')) diff --git a/examples/cluster/plot_kmeans_digits.py b/examples/cluster/plot_kmeans_digits.py index c1982ce268fc7..0ccc3b9e230ee 100644 --- a/examples/cluster/plot_kmeans_digits.py +++ b/examples/cluster/plot_kmeans_digits.py @@ -48,8 +48,8 @@ print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_) print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) -#print ("Silhouette Coefficient: %0.3f" % -# metrics.silhouette_score(D, km.labels_, metric='precomputed')) +print "Adjusted Mutual Information: %0.3f" % \ + metrics.ami_score(labels_true, labels) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 3fbc895f8968c..f8f8931cc3d5f 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -16,5 +16,5 @@ from .cluster import completeness_score from .cluster import v_measure_score from .cluster import silhouette_score -from .cluster import mutual_information +from .cluster import mutual_information_score from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 0bf0af21a5614..5f42b3a011787 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -9,6 +9,6 @@ homogeneity_score, completeness_score, v_measure_score, adjusted_rand_score, ami_score, expected_mutual_information, - mutual_information, contingency_matrix) + mutual_information_score, contingency_matrix) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index eb18a0a63ba7b..2021205c49590 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -39,7 +39,7 @@ def check_clusterings(labels_true, labels_pred): def contingency_matrix(labels_true, labels_pred, eps=None): - """ Build a contengency matrix describing the relationship between labels. + """Build a contengency matrix describing the relationship between labels. Parameters ---------- @@ -156,7 +156,7 @@ def adjusted_rand_score(labels_true, labels_pred): See also -------- - - ami_score: Adjusted Mutual Information (TODO: implement me!) 
+ - ami_score: Adjusted Mutual Information """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -488,7 +488,7 @@ def v_measure_score(labels_true, labels_pred): return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] -def mutual_information(labels_true, labels_pred, contingency=None): +def mutual_information_score(labels_true, labels_pred, contingency=None): """Adjusted Mutual Information between two clusterings The Mutual Information is a measure of the similarity between two labels @@ -525,6 +525,10 @@ def mutual_information(labels_true, labels_pred, contingency=None): ------- mi: float Mutual information, a non-negative value + + See also + -------- + - ami_score: Adjusted Mutual Information """ if contingency is None: labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -572,6 +576,11 @@ def ami_score(labels_true, labels_pred): ami: float score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling + See also + -------- + - adjusted_rand_score: Adjusted Rand Index + - mutual_information_score: Mutual Information (not adjusted for chance) + Examples -------- @@ -603,17 +612,18 @@ def ami_score(labels_true, labels_pred): contingency = contingency_matrix(labels_true, labels_pred) contingency = np.array(contingency, dtype='float') # Calculate the MI for the two clusterings - mi = mutual_information(labels_true, labels_pred, contingency=contingency) + mi = mutual_information_score(labels_true, labels_pred, + contingency=contingency) # Calcualte the expected value for the mutual information emi = expected_mutual_information(contingency, n_samples) - # Calculate entropy for each labelling + # Calculate entropy for each labeling h_true, h_pred = entropy(labels_true), entropy(labels_pred) ami = (mi - emi) / (max(h_true, h_pred) - emi) return ami def expected_mutual_information(contingency, n_samples): - """ Calculate the expected mutual information for two labellings. """ + """Calculate the expected mutual information for two labelings.""" R, C = contingency.shape N = n_samples a = np.sum(contingency, axis=1) @@ -648,7 +658,7 @@ def expected_mutual_information(contingency, n_samples): def entropy(labels): - """ Calculates the entropy for a labelling. 
""" + """Calculates the entropy for a labeling.""" pi = np.array([np.sum(labels == i) for i in np.unique(labels)], dtype='float') pi = pi[pi > 0] diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index 1c0c197b016c7..d470b102bc5df 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -6,7 +6,7 @@ from sklearn.metrics.cluster import v_measure_score from sklearn.metrics.cluster import homogeneity_completeness_v_measure from sklearn.metrics.cluster import ami_score -from sklearn.metrics.cluster import mutual_information +from sklearn.metrics.cluster import mutual_information_score from sklearn.metrics.cluster import expected_mutual_information from sklearn.metrics.cluster import contingency_matrix @@ -140,7 +140,7 @@ def test_ami_score(): labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) # Mutual information - mi = mutual_information(labels_a, labels_b) + mi = mutual_information_score(labels_a, labels_b) assert_almost_equal(mi, 0.59182, 5) # Expected mutual information C = contingency_matrix(labels_a, labels_b) From 867ec2fc53b2ebb1cc7d06b9046ed72fd1a12b46 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 10:56:12 +1100 Subject: [PATCH 07/30] Higher level import for ami_score --- sklearn/metrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index f8f8931cc3d5f..cfd56dc57ba19 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -17,4 +17,5 @@ from .cluster import v_measure_score from .cluster import silhouette_score from .cluster import mutual_information_score +from .cluster import ami_score from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels From 18b9a7883b40c3f7640c734abbcd8744c01f69ea Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 10:58:32 +1100 Subject: [PATCH 08/30] There is an overflow problem. 
It can be reproduced with the plot_adjusted_for_chance_measures.py example.
Committing to get help on error
---
 sklearn/metrics/cluster/supervised.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py
index 2021205c49590..203a519054475 100644
--- a/sklearn/metrics/cluster/supervised.py
+++ b/sklearn/metrics/cluster/supervised.py
@@ -638,7 +638,8 @@ def expected_mutual_information(contingency, n_samples):
         for j in range(C):
             # numerator of the third term
             num3 = factA[i] * factB[j] * factNA[i] * factNB[j]
-            assert np.isfinite(num3)
+            assert np.isfinite(num3), "%r,%r,%r,%r" % (
+                factA[i], factB[j], factNA[i], factNB[j])
             start = int(max(a[i] + b[j] - N, 1))
             end = int(min(a[i], b[j]) + 1)
             for nij in range(start, end):

From aaf5c23959c272651a5f58c0c9c61b07573facbe Mon Sep 17 00:00:00 2001
From: Robert Layton
Date: Fri, 21 Oct 2011 13:26:37 +1100
Subject: [PATCH 09/30] Narrative doc, and I think I fixed the overflow issue (more tests to come)

---
 doc/modules/clustering.rst            | 137 ++++++++++++++++++++++
 sklearn/metrics/cluster/supervised.py |  64 +++++++-----
 2 files changed, 177 insertions(+), 24 deletions(-)

diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index 41f904e1a4b9d..e1e6fc4e4102a 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -472,6 +472,143 @@ by defining the adjusted Rand index as follows:
   `_
 
+Adjusted Mutual Information
+---------------------------
+
+Presentation and usage
+~~~~~~~~~~~~~~~~~~~~~~
+
+Given the knowledge of the ground truth class assignments ``labels_true``
+and our clustering algorithm assignments of the same samples
+``labels_pred``, the **Adjusted Mutual Information** is a function that
+measures the **agreement** of the two assignments, ignoring permutations
+and **with chance normalization**::
+
+  >>> from sklearn import metrics
+  >>> labels_true = [0, 0, 0, 1, 1, 1]
+  >>> labels_pred = [0, 0, 1, 1, 2, 2]
+
+  >>> metrics.ami_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.24...
+
+One can permute 0 and 1 in the predicted labels and rename `2` by `3` and get
+the same score::
+
+  >>> labels_pred = [1, 1, 0, 0, 3, 3]
+  >>> metrics.ami_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.24...
+
+Furthermore, :func:`ami_score` is **symmetric**: swapping the arguments
+does not change the score. It can thus be used as a **consensus
+measure**::
+
+  >>> metrics.ami_score(labels_pred, labels_true)  # doctest: +ELLIPSIS
+  0.24...
+
+Perfect labeling is scored 1.0::
+
+  >>> labels_pred = labels_true[:]
+  >>> metrics.ami_score(labels_true, labels_pred)
+  1.0
+
+Bad (e.g. independent) labelings have scores of zero::
+
+  >>> labels_true = [0, 1, 2, 0, 3, 4, 5, 1]
+  >>> labels_pred = [1, 1, 0, 0, 2, 2, 2, 2]
+  >>> metrics.ami_score(labels_true, labels_pred)  # doctest: +ELLIPSIS
+  0.0...
+
+
+Advantages
+~~~~~~~~~~
+
+- **Random (uniform) label assignments have an AMI score close to 0.0**
+  for any value of ``n_clusters`` and ``n_samples`` (which is not the
+  case for raw Mutual Information or the V-measure for instance).
+
+- **Bounded range [0, 1]**: Values close to zero indicate two label
+  assignments that are largely independent, while values close to one
+  indicate significant agreement. Further, values of exactly 0 indicate
+  **purely** independent label assignments and an AMI of exactly 1 indicates
+  that the two label assignments are equal (with or without permutation).
+
+- **No assumption is made on the cluster structure**: it can be used
+  to compare clustering algorithms such as k-means, which assumes isotropic
+  blob shapes, with results of spectral clustering algorithms, which can
+  find clusters with "folded" shapes.
+
+
+Drawbacks
+~~~~~~~~~
+
+- Contrary to inertia, **AMI requires knowledge of the ground truth
+  classes**, which is almost never available in practice or requires manual
+  assignment by human annotators (as in the supervised learning setting).
+
+  However AMI can also be useful in a purely unsupervised setting as a
+  building block for a Consensus Index that can be used for clustering
+  model selection.
+
+
+.. topic:: Examples:
+
+ * :ref:`example_cluster_plot_adjusted_for_chance_measures.py`: Analysis of
+   the impact of the dataset size on the value of clustering measures
+   for random assignments. This example also includes the Adjusted Rand
+   Index.
+
+
+Mathematical formulation
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+Assume two label assignments (of the same data), :math:`U` with :math:`R`
+classes and :math:`V` with :math:`C` classes. The entropy of either is the
+amount of uncertainty for an array of labels, and can be calculated as:
+
+.. math:: H(U) = - \sum_{i=1}^{|R|} P(i) \log(P(i))
+
+where :math:`P(i)` is the proportion of instances in :math:`U` that are in
+class :math:`R_i`. Likewise, for :math:`V`:
+
+.. math:: H(V) = - \sum_{j=1}^{|C|} P'(j) \log(P'(j))
+
+where :math:`P'(j)` is the proportion of instances in :math:`V` that are in
+class :math:`C_j`.
+
+The (non-adjusted) mutual information between :math:`U` and :math:`V` is
+calculated by:
+
+.. math:: MI(U, V) = \sum_{i=1}^{|R|} \sum_{j=1}^{|C|} P(i, j)
+   \log\left(\frac{P(i, j)}{P(i) P'(j)}\right)
+
+where :math:`P(i, j)` is the proportion of instances with label :math:`R_i`
+and also with label :math:`C_j`.
+
+This value of the mutual information is not adjusted for chance and will tend
+to increase as the number of different labels (clusters) increases, regardless
+of the actual amount of "mutual information" between the label assignments.
+
+The expected value for the mutual information can be calculated using the
+following equation, from Vinh, Epps and Bailey (2009). In this equation,
+:math:`a_i` is the number of instances with label :math:`U_i` and
+:math:`b_j` is the number of instances with label :math:`V_j`:
+
+.. math:: E\{MI(U, V)\} = \sum_{i=1}^{|R|} \sum_{j=1}^{|C|}
+   \sum_{n_{ij}=(a_i+b_j-N)^+}^{\min(a_i, b_j)} \frac{n_{ij}}{N}
+   \log\left(\frac{N \cdot n_{ij}}{a_i b_j}\right)
+   \frac{a_i! b_j! (N-a_i)! (N-b_j)!}
+        {N! n_{ij}! (a_i-n_{ij})! (b_j-n_{ij})! (N-a_i-b_j+n_{ij})!}
+
+Using the expected value, the adjusted mutual information can then be
+calculated using a similar form to that of the adjusted Rand index:
+
+.. math:: AMI = \frac{MI - E\{MI\}}{\max(H(U), H(V)) - E\{MI\}}
+
+.. topic:: References
+
+ * Vinh, N. X., Epps, J., and Bailey, J. (2009). "Information theoretic measures
+   for clusterings comparison". Proceedings of the 26th Annual International
+   Conference on Machine Learning - ICML '09.
+   doi:10.1145/1553374.1553511. ISBN 9781605585161.
+ + * `Wikipedia entry for the Adjusted Mutual Information + `_ + Homogeneity, completeness and V-measure --------------------------------------- diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 203a519054475..2114badda415a 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -614,7 +614,7 @@ def ami_score(labels_true, labels_pred): # Calculate the MI for the two clusterings mi = mutual_information_score(labels_true, labels_pred, contingency=contingency) - # Calcualte the expected value for the mutual information + # Calculate the expected value for the mutual information emi = expected_mutual_information(contingency, n_samples) # Calculate entropy for each labeling h_true, h_pred = entropy(labels_true), entropy(labels_pred) @@ -626,35 +626,51 @@ def expected_mutual_information(contingency, n_samples): """Calculate the expected mutual information for two labelings.""" R, C = contingency.shape N = n_samples - a = np.sum(contingency, axis=1) - b = np.sum(contingency, axis=0) - factA = factorial(a) - factB = factorial(b) - factNA = factorial(N - a) - factNB = factorial(N - b) - factN = factorial(N) + a = np.sum(contingency, axis=1, dtype='int') + b = np.sum(contingency, axis=0, dtype='int') emi = 0 for i in range(R): for j in range(C): - # numerator of the third term - num3 = factA[i] * factB[j] * factNA[i] * factNB[j] - assert np.isfinite(num3), "%r,%r,%r,%r" % ( - factA[i], factB[j], factNA[i], factNB[j]) start = int(max(a[i] + b[j] - N, 1)) end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - nij = float(nij) - # term 1: nij / N - term1 = nij / N - assert np.isfinite(term1) - # term 2: log(N.nij / ai.bj) - term2 = np.log((N * nij) / (a[i] * b[j])) - assert np.isfinite(term2) - # denominator of term 3 - den3 = float(factN * factorial(nij) * factorial(a[i] - nij) - * factorial(b[j] - nij) - * factorial(N - a[i] - b[j] + nij)) - emi += term1 * term2 * (num3 / den3) + term1 = nij / float(N) + assert np.isfinite(term1), ("term1: %r, nij=%d, N=%d" % + (term1, nij, N)) + term2 = np.log(float(N * nij) / (a[i] * b[j])) + assert np.isfinite(term2), ("term2: %r, nij=%d, N=%d, a[i]=%d, b[j]=%d" % + (term2, nij, N, a[i], b[j])) + + # a! / (a - n)! + term3a = np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1)) + assert np.isfinite(term3a), ("term3a: %r, a[i]=%d, nij=%d" % + (term3a, a[i], nij)) + # b! / (b - n)! + term3b = np.multiply.reduce(range(b[j] - nij + 1, b[j] + 1)) + assert np.isfinite(term3b), ("term3b: %r, b[j]=%d, nij=%d" % + (term3b, b[j], nij)) + # (N - a)! / N! + t = np.multiply.reduce(range(N - a[i] + 1, N + 1)) + if t == 0: + continue + term3c = 1. / t + assert np.isfinite(term3c), ("term3c: %r, a[i]=%d, N=%d, t=%.3f" % + (term3c, a[i], N, t)) + # (N - b)! / (N - a - b - n)! + num3d = N - b[j] + 1 + den3d = N - a[i] - b[j] + nij + 1 + if num3d > den3d: + term3d = np.multiply.reduce(range(den3d, num3d)) + else: + term3d = np.multiply.reduce(range(num3d, den3d)) + term3d = 1. / term3d + assert np.isfinite(term3d), ("term3d: %r, N=%d, a[i]=%d, b[j]=%d, nij=%d" % + (term3d, N, a[i], b[j], nij)) + # 1 / n! + term3e = 1. / factorial(nij) + # Add the product of all terms + emi += (term1 * term2 * term3a * term3b + * term3c * term3d * term3e) return emi From fb0c3f8fe898bfbf03f59870c6fa30eabbe06b8a Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 19:36:56 +1100 Subject: [PATCH 10/30] Fixed logs to match the matlab code results. 
This fixes the tests (which needed a bit of updating) --- sklearn/metrics/cluster/supervised.py | 4 ++-- sklearn/metrics/cluster/tests/test_supervised.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 2114badda415a..2580225d90b8b 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -539,7 +539,7 @@ def mutual_information_score(labels_true, labels_pred, contingency=None): pi /= np.sum(pi) pj = np.sum(contingency, axis=0) pj /= np.sum(pj) - mi = contingency * np.log2(contingency / np.outer(pi, pj)) + mi = contingency * np.log(contingency / np.outer(pi, pj)) return np.sum(mi[np.isfinite(mi)]) @@ -680,4 +680,4 @@ def entropy(labels): dtype='float') pi = pi[pi > 0] pi /= np.sum(pi) - return -np.sum(pi * np.log2(pi)) + return -np.sum(pi * np.log(pi)) diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index d470b102bc5df..35298ee019553 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -141,7 +141,7 @@ def test_ami_score(): labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) # Mutual information mi = mutual_information_score(labels_a, labels_b) - assert_almost_equal(mi, 0.59182, 5) + assert_almost_equal(mi, 0.41022, 5) # Expected mutual information C = contingency_matrix(labels_a, labels_b) n_samples = np.sum(C) From db845fa926ba801bdb09da21732cf6f6ee0f0d47 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 19:59:20 +1100 Subject: [PATCH 11/30] Test now tests a much larger array --- sklearn/metrics/cluster/tests/test_supervised.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index 35298ee019553..0ede7dd9f8a1f 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -152,3 +152,8 @@ def test_ami_score(): assert_almost_equal(ami, 0.27502, 5) ami = ami_score([1, 1, 2, 2], [2, 2, 3, 3]) assert_equal(ami, 1.0) + # Test with a very large array + a100 = np.array([list(a) * 110]) + b100 = np.array([list(b) * 110]) + # This is not accurate to more than 2 places + assert_almost_equal(ami, 0.37, 2) From 77ec530a699d5925ff4019fd0b552a37dc658c0f Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 20:06:56 +1100 Subject: [PATCH 12/30] Test actually does what I meant it to do, and works sufficiently --- sklearn/metrics/cluster/tests/test_supervised.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index 0ede7dd9f8a1f..b47badacb7969 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -153,7 +153,8 @@ def test_ami_score(): ami = ami_score([1, 1, 2, 2], [2, 2, 3, 3]) assert_equal(ami, 1.0) # Test with a very large array - a100 = np.array([list(a) * 110]) - b100 = np.array([list(b) * 110]) + a110 = np.array([list(labels_a) * 110]).flatten() + b110 = np.array([list(labels_b) * 110]).flatten() + ami = ami_score(a110, b110) # This is not accurate to more than 2 places assert_almost_equal(ami, 0.37, 2) From 52fc4c57515848609100ba920a6fb08150e01e2b Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 21:03:05 +1100 
Subject: [PATCH 13/30] Fixed this example. Tested the others (they worked!) --- examples/cluster/plot_kmeans_digits.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cluster/plot_kmeans_digits.py b/examples/cluster/plot_kmeans_digits.py index 0ccc3b9e230ee..e0e0e02b165cc 100644 --- a/examples/cluster/plot_kmeans_digits.py +++ b/examples/cluster/plot_kmeans_digits.py @@ -49,7 +49,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels_true, labels) + metrics.ami_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) @@ -65,8 +65,8 @@ print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_) print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) -#print ("Silhouette Coefficient: %0.3f" % -# metrics.silhouette_score(D, km.labels_, metric='precomputed')) +print "Adjusted Mutual Information: %0.3f" % \ + metrics.ami_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) @@ -85,8 +85,8 @@ print "V-measure: %0.3f" % metrics.v_measure_score(labels, km.labels_) print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) -#print ("Silhouette Coefficient: %0.3f" % -# metrics.silhouette_score(D, km.labels_, metric='precomputed')) +print "Adjusted Mutual Information: %0.3f" % \ + metrics.ami_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) From c021b59101f5f219b58ced5d7b9158c8cbdd3e21 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Fri, 21 Oct 2011 21:07:29 +1100 Subject: [PATCH 14/30] pep8 and pyflakes --- examples/cluster/plot_color_quantization.py | 4 ++-- sklearn/metrics/cluster/__init__.py | 1 - sklearn/metrics/cluster/supervised.py | 13 ------------- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/examples/cluster/plot_color_quantization.py b/examples/cluster/plot_color_quantization.py index 518376e86b816..d2af736e242bb 100644 --- a/examples/cluster/plot_color_quantization.py +++ b/examples/cluster/plot_color_quantization.py @@ -10,7 +10,7 @@ In this example, pixels are represented in a 3D-space and K-means is used to find 64 color clusters. In the image processing literature, the codebook -obtained from K-means (the cluster centers) is called the color palette. Using a +obtained from K-means (the cluster centers) is called the color palette. Using a single byte, up to 256 colors can be addressed, whereas an RGB encoding requires 3 bytes per pixel. The GIF file format, for example, uses such a palette. @@ -61,7 +61,7 @@ print "done in %0.3fs." 
% (time() - t0) -codebook_random = shuffle(image_array, random_state=0)[:n_colors+1] +codebook_random = shuffle(image_array, random_state=0)[:n_colors + 1] print "Predicting color indices on the full image (random)" t0 = time() dist = euclidean_distances(codebook_random, image_array, squared=True) diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 5f42b3a011787..63c4129ac2592 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -11,4 +11,3 @@ ami_score, expected_mutual_information, mutual_information_score, contingency_matrix) from unsupervised import silhouette_score, silhouette_samples - diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 2580225d90b8b..51ce647437ac6 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -635,27 +635,16 @@ def expected_mutual_information(contingency, n_samples): end = int(min(a[i], b[j]) + 1) for nij in range(start, end): term1 = nij / float(N) - assert np.isfinite(term1), ("term1: %r, nij=%d, N=%d" % - (term1, nij, N)) term2 = np.log(float(N * nij) / (a[i] * b[j])) - assert np.isfinite(term2), ("term2: %r, nij=%d, N=%d, a[i]=%d, b[j]=%d" % - (term2, nij, N, a[i], b[j])) - # a! / (a - n)! term3a = np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1)) - assert np.isfinite(term3a), ("term3a: %r, a[i]=%d, nij=%d" % - (term3a, a[i], nij)) # b! / (b - n)! term3b = np.multiply.reduce(range(b[j] - nij + 1, b[j] + 1)) - assert np.isfinite(term3b), ("term3b: %r, b[j]=%d, nij=%d" % - (term3b, b[j], nij)) # (N - a)! / N! t = np.multiply.reduce(range(N - a[i] + 1, N + 1)) if t == 0: continue term3c = 1. / t - assert np.isfinite(term3c), ("term3c: %r, a[i]=%d, N=%d, t=%.3f" % - (term3c, a[i], N, t)) # (N - b)! / (N - a - b - n)! num3d = N - b[j] + 1 den3d = N - a[i] - b[j] + nij + 1 @@ -664,8 +653,6 @@ def expected_mutual_information(contingency, n_samples): else: term3d = np.multiply.reduce(range(num3d, den3d)) term3d = 1. / term3d - assert np.isfinite(term3d), ("term3d: %r, N=%d, a[i]=%d, b[j]=%d, nij=%d" % - (term3d, N, a[i], b[j], nij)) # 1 / n! term3e = 1. / factorial(nij) # Add the product of all terms From b7d7642241635dcadd5248f32433912d10134541 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 23 Oct 2011 10:43:23 +0200 Subject: [PATCH 15/30] measure runtimes for various clustering metrics in adjusted for chance example --- examples/cluster/plot_adjusted_for_chance_measures.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index 8f05747d0fd9c..f8e8d77d1306b 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -27,11 +27,12 @@ import numpy as np import pylab as pl +from time import time from sklearn import metrics def uniform_labelings_scores(score_func, n_samples, n_clusters_range, - fixed_n_classes=None, n_runs=10, seed=42): + fixed_n_classes=None, n_runs=5, seed=42): """Compute score for 2 random uniform cluster labelings. 
Both random labelings have the same number of clusters for each value @@ -74,7 +75,9 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, print "Computing %s for %d values of n_clusters and n_samples=%d" % ( score_func.__name__, len(n_clusters_range), n_samples) + t0 = time() scores = uniform_labelings_scores(score_func, n_samples, n_clusters_range) + print "done in %0.3fs" % (time() - t0) plots.append(pl.errorbar( n_clusters_range, scores.mean(axis=1), scores.std(axis=1))) names.append(score_func.__name__) @@ -102,8 +105,10 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, print "Computing %s for %d values of n_clusters and n_samples=%d" % ( score_func.__name__, len(n_clusters_range), n_samples) + t0 = time() scores = uniform_labelings_scores(score_func, n_samples, n_clusters_range, fixed_n_classes=n_classes) + print "done in %0.3fs" % (time() - t0) plots.append(pl.errorbar( n_clusters_range, scores.mean(axis=1), scores.std(axis=1))) names.append(score_func.__name__) From 3845fd8900884c9a1345b1cbc36723f00e727097 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 23 Oct 2011 10:44:06 +0200 Subject: [PATCH 16/30] FIX warnings by avoiding 0.0 values in the log + cosmit --- sklearn/metrics/cluster/supervised.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 51ce647437ac6..3a6efc8addc33 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -539,8 +539,10 @@ def mutual_information_score(labels_true, labels_pred, contingency=None): pi /= np.sum(pi) pj = np.sum(contingency, axis=0) pj /= np.sum(pj) - mi = contingency * np.log(contingency / np.outer(pi, pj)) - return np.sum(mi[np.isfinite(mi)]) + outer = np.outer(pi, pj) + nnz = contingency != 0.0 + mi = contingency[nnz] * np.log(contingency[nnz] / outer[nnz]) + return mi.sum() def ami_score(labels_true, labels_pred): @@ -584,7 +586,8 @@ def ami_score(labels_true, labels_pred): Examples -------- - Perfect labelings are both homogeneous and complete, hence have score 1.0:: + Perfect labelings are both homogeneous and complete, hence have + score 1.0:: >>> from sklearn.metrics.cluster import ami_score >>> ami_score([0, 0, 1, 1], [0, 0, 1, 1]) 1.0 >>> ami_score([0, 0, 1, 1], [1, 1, 0, 0]) 1.0 - If class members are completely split across different clusters, the assignment is totally incomplete, hence the AMI is null:: From 77ff09808aa8f646bd95086c636b9719be5b3d1d Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sun, 23 Oct 2011 21:24:55 +1100 Subject: [PATCH 17/30] Optimising the expected mutual information code. Still a long way to go, but I've halved the time from approx 30 seconds to 16 seconds for an example --- .../plot_adjusted_for_chance_measures.py | 6 ++- sklearn/metrics/cluster/supervised.py | 42 +++++++++++++------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index f8e8d77d1306b..ca053a0b2acac 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -59,7 +59,8 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, score_funcs = [ metrics.adjusted_rand_score, metrics.v_measure_score, - metrics.ami_score + metrics.ami_score, +
metrics.mutual_information_score, ] # 2 independent random clusterings with equal cluster number @@ -79,7 +80,8 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, scores = uniform_labelings_scores(score_func, n_samples, n_clusters_range) print "done in %0.3fs" % (time() - t0) plots.append(pl.errorbar( - n_clusters_range, scores.mean(axis=1), scores.std(axis=1))) + # n_clusters_range, scores.mean(axis=1), scores.std(axis=1))) + n_clusters_range, np.median(scores, axis=1), scores.std(axis=1))) names.append(score_func.__name__) pl.title("Clustering measures for 2 random uniform labelings\n" diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 3a6efc8addc33..17a83a7e6817b 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -630,36 +630,52 @@ def expected_mutual_information(contingency, n_samples): N = n_samples a = np.sum(contingency, axis=1, dtype='int') b = np.sum(contingency, axis=0, dtype='int') + # While nijs[0] will never be used, having it simplifies the indexing. + nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float') + # Compute values that are used multiple times. + # term1 is nij / N + term1 = nijs / N + # term2 uses the outer product + ab_outer = np.outer(a, b) + # term2 uses N * nij + Nnij = N * nijs + # term3a (was np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1))) + # term3 has a component: 1 / n! + term3e = 1. / factorial(nijs) + # numerator for term 3d + num3d = N - b + 1 emi = 0 for i in range(R): for j in range(C): start = int(max(a[i] + b[j] - N, 1)) end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - term1 = nij / float(N) - term2 = np.log(float(N * nij) / (a[i] * b[j])) + term2 = np.log(Nnij[nij] / ab_outer[i][j]) # a! / (a - n)! - term3a = np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1)) + term3a = np.multiply.reduce(np.arange(a[i] - nij + 1, a[i] + 1, + dtype='int')) # b! / (b - n)! - term3b = np.multiply.reduce(range(b[j] - nij + 1, b[j] + 1)) + term3b = np.multiply.reduce(np.arange(b[j] - nij + 1, b[j] + 1, + dtype='int')) # (N - a)! / N! - t = np.multiply.reduce(range(N - a[i] + 1, N + 1)) + t = np.multiply.reduce(np.arange(N - a[i] + 1, N + 1, + dtype='int')) if t == 0: continue term3c = 1. / t # (N - b)! / (N - a - b - n)! - num3d = N - b[j] + 1 + num3dj = num3d[j] den3d = N - a[i] - b[j] + nij + 1 - if num3d > den3d: - term3d = np.multiply.reduce(range(den3d, num3d)) + if num3dj > den3d: + term3d = np.multiply.reduce(np.arange(den3d, num3dj, + dtype='int')) else: - term3d = np.multiply.reduce(range(num3d, den3d)) + term3d = np.multiply.reduce(np.arange(num3dj, den3d, + dtype='int')) term3d = 1. / term3d - # 1 / n! - term3e = 1. 
/ factorial(nij) # Add the product of all terms - emi += (term1 * term2 * term3a * term3b - * term3c * term3d * term3e) + emi += (term1[nij] * term2 * term3a * term3b + * term3c * term3d * term3e[nij]) return emi From 13a48df93547fa048a9d0b0b5a108be98c061226 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sat, 29 Oct 2011 15:23:01 +1100 Subject: [PATCH 18/30] Adding old version of EMI, as I'm about to change it --- sklearn/metrics/cluster/supervised.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 17a83a7e6817b..4a0a2d4a4c8ce 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -659,7 +659,7 @@ def expected_mutual_information(contingency, n_samples): dtype='int')) # (N - a)! / N! t = np.multiply.reduce(np.arange(N - a[i] + 1, N + 1, - dtype='int')) + dtype='float')) if t == 0: continue term3c = 1. / t From ccc3cad76d5376623cdcf6682e2c5b0087a1b24d Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Wed, 2 Nov 2011 20:26:52 +1100 Subject: [PATCH 19/30] This version doesn't work either. I am uploading it for history's sake. I'll be undoing these changes with the gammaln function to see how that goes. --- .../plot_adjusted_for_chance_measures.py | 8 +++ sklearn/metrics/cluster/supervised.py | 65 ++++++++++--------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index ca053a0b2acac..c579c7cd78113 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -91,6 +91,10 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, pl.legend(plots, names) pl.ylim(ymin=-0.05, ymax=1.05) pl.show() +# The new version of matplotlib on my computer has a bug where show does nothing +# Sorry if I forget to take this out! +pl.savefig("adj1.png") + # Random labeling with varying n_clusters against ground class labels # with fixed number of clusters @@ -122,3 +126,7 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, pl.ylim(ymin=-0.05, ymax=1.05) pl.legend(plots, names) pl.show() + +# The new version of matplotlib on my computer has a bug where show does nothing +# Sorry if I forget to take this out! +pl.savefig("adj2.png") diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 4a0a2d4a4c8ce..c812af5574741 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -624,8 +624,36 @@ def ami_score(labels_true, labels_pred): return ami +def _accumulate_factorials(all_factors): + """Calculates the product of the given factorials. + + This function solves equations of the form: + + \frac{\prod_i n_i!}{\prod_j d_j!} + + Parameters + ---------- + all_factors: list of signed integers + Integers are positive if they are a numerator in the equation, + and negative if they are in the denominator. + + Returns + ------- + t: float + The product of the factorials of the values that are positive, + divided by the product of the factorials of the values that are + negative.
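+
+    For example, all_factors = np.array([3, -2, -2]) computes
+    3! / (2! * 2!) = 6 / 4 = 1.5.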
+ """ + a = np.zeros((all_factors.max() + 1)) + b = np.arange(all_factors.max() + 1, dtype='float') + for factor in all_factors: + a[:(abs(factor) + 1)] += np.sign(factor) + return np.exp(np.sum(np.dot(a[2:], np.log(b[2:])))) + + def expected_mutual_information(contingency, n_samples): """Calculate the expected mutual information for two labelings.""" + np.seterr(all='raise') R, C = contingency.shape N = n_samples a = np.sum(contingency, axis=1, dtype='int') @@ -639,43 +667,20 @@ def expected_mutual_information(contingency, n_samples): ab_outer = np.outer(a, b) # term2 uses N * nij Nnij = N * nijs - # term3a (was np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1))) - # term3 has a component: 1 / n! - term3e = 1. / factorial(nijs) - # numerator for term 3d - num3d = N - b + 1 emi = 0 for i in range(R): for j in range(C): start = int(max(a[i] + b[j] - N, 1)) end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - term2 = np.log(Nnij[nij] / ab_outer[i][j]) - # a! / (a - n)! - term3a = np.multiply.reduce(np.arange(a[i] - nij + 1, a[i] + 1, - dtype='int')) - # b! / (b - n)! - term3b = np.multiply.reduce(np.arange(b[j] - nij + 1, b[j] + 1, - dtype='int')) - # (N - a)! / N! - t = np.multiply.reduce(np.arange(N - a[i] + 1, N + 1, - dtype='float')) - if t == 0: - continue - term3c = 1. / t - # (N - b)! / (N - a - b - n)! - num3dj = num3d[j] - den3d = N - a[i] - b[j] + nij + 1 - if num3dj > den3d: - term3d = np.multiply.reduce(np.arange(den3d, num3dj, - dtype='int')) - else: - term3d = np.multiply.reduce(np.arange(num3dj, den3d, - dtype='int')) - term3d = 1. / term3d + term1 = nij / N + term2 = np.log(N * nij) - np.log(a[i] * b[j]) + factors = np.array([a[i], b[j], (N-a[i]), (N-b[j]), + -N, -nij, -(a[i] - nij), -(b[j] - nij), + -(N - a[i] - b[j] + nij)]) + term3 = _accumulate_factorials(factors) # Add the product of all terms - emi += (term1[nij] * term2 * term3a * term3b - * term3c * term3d * term3e[nij]) + emi += (term1 * term2 * term3) return emi From fb6fe360ba1c521b5742aa6f49fc4512c479feb2 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Wed, 2 Nov 2011 21:31:09 +1100 Subject: [PATCH 20/30] Initial usage of gammaln. Not yet tested --- sklearn/metrics/cluster/supervised.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index c812af5574741..222dfb8b1433b 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -9,6 +9,7 @@ from math import log from scipy.misc import comb, factorial +from scipy.special import gammaln import numpy as np @@ -655,7 +656,7 @@ def expected_mutual_information(contingency, n_samples): """Calculate the expected mutual information for two labelings.""" np.seterr(all='raise') R, C = contingency.shape - N = n_samples + N = float(n_samples) a = np.sum(contingency, axis=1, dtype='int') b = np.sum(contingency, axis=0, dtype='int') # While nijs[0] will never be used, having it simplifies the indexing. @@ -678,7 +679,9 @@ def expected_mutual_information(contingency, n_samples): factors = np.array([a[i], b[j], (N-a[i]), (N-b[j]), -N, -nij, -(a[i] - nij), -(b[j] - nij), -(N - a[i] - b[j] + nij)]) - term3 = _accumulate_factorials(factors) + gln = gammaln(np.abs(factors) + 1) # n! 
= gamma(n + 1) + signs = np.sign(factors) + term3 = np.exp(np.sum(np.dot(signs, gln))) # Add the product of all terms emi += (term1 * term2 * term3) return emi From 75bf5b0da5359cf40fa6d02030b035e53a88681a Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sat, 5 Nov 2011 13:48:11 +1100 Subject: [PATCH 21/30] Still overflows, but the closest so far. Using gammaln --- sklearn/metrics/cluster/supervised.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 560b2e6bb37b1..3987019f80bd7 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -669,19 +669,27 @@ def expected_mutual_information(contingency, n_samples): # term2 uses N * nij Nnij = N * nijs emi = 0 + # signs is the same for all loops + signs = np.array([ 1, 1, 1, 1, -1, -1, -1, -1, -1]) for i in range(R): for j in range(C): start = int(max(a[i] + b[j] - N, 1)) end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - term1 = nij / N + term1 = nij / N # Moved '/ N' to reduce term 3 term2 = np.log(N * nij) - np.log(a[i] * b[j]) factors = np.array([a[i], b[j], (N-a[i]), (N-b[j]), -N, -nij, -(a[i] - nij), -(b[j] - nij), -(N - a[i] - b[j] + nij)]) gln = gammaln(np.abs(factors) + 1) # n! = gamma(n + 1) - signs = np.sign(factors) - term3 = np.exp(np.sum(np.dot(signs, gln))) + try: + ev = np.exp(np.multiply(signs, gln)) + term3 = np.multiply.reduce(ev) + except: + print gln + print factors + print signs + raise # Add the product of all terms emi += (term1 * term2 * term3) return emi
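The identity behind the eventual fix is n! = gamma(n + 1): a ratio of factorial products becomes a sum and difference of scipy.special.gammaln values, exponentiated once at the end. A minimal standalone sketch of the idea, outside the patch series (the helper name is illustrative, not part of the code):

    import numpy as np
    from scipy.special import gammaln

    def log_factorial_ratio(numerators, denominators):
        # prod(n!) / prod(d!), accumulated in log space to avoid overflow
        num = np.sum(gammaln(np.asarray(numerators, dtype='float') + 1))
        den = np.sum(gammaln(np.asarray(denominators, dtype='float') + 1))
        return np.exp(num - den)

    print log_factorial_ratio([100], [98, 2])  # ~4950.0, i.e. comb(100, 2)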
From 01ae29edeae88440ad206abed18de2749865349f Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sat, 5 Nov 2011 15:07:33 +1100 Subject: [PATCH 22/30] It works! Still have some optimisation to do, but it works for larger arrays --- sklearn/metrics/cluster/supervised.py | 35 ++++++++++++--------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 3987019f80bd7..9eefbe6dec1f7 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -661,37 +661,34 @@ def expected_mutual_information(contingency, n_samples): b = np.sum(contingency, axis=0, dtype='int') # While nijs[0] will never be used, having it simplifies the indexing. nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float') + nijs[0] = 1 # stops divide by zero problems # Compute values that are used multiple times. # term1 is nij / N term1 = nijs / N # term2 uses the outer product - ab_outer = np.outer(a, b) + log_ab_outer = np.log(np.outer(a, b)) # term2 uses N * nij - Nnij = N * nijs + log_Nnij = np.log(N * nijs) + # term3 uses these factors + gln_a = gammaln(a) + gln_b = gammaln(b) + gln_Na = gammaln(N-a) + gln_Nb = gammaln(N-b) + gln_N = gammaln(N) + gln_nij = gammaln(nijs) emi = 0 - # signs is the same for all loops - signs = np.array([ 1, 1, 1, 1, -1, -1, -1, -1, -1]) for i in range(R): for j in range(C): start = int(max(a[i] + b[j] - N, 1)) end = int(min(a[i], b[j]) + 1) for nij in range(start, end): - term1 = nij / N # Moved '/ N' to reduce term 3 - term2 = np.log(N * nij) - np.log(a[i] * b[j]) - factors = np.array([a[i], b[j], (N-a[i]), (N-b[j]), - -N, -nij, -(a[i] - nij), -(b[j] - nij), - -(N - a[i] - b[j] + nij)]) - gln = gammaln(np.abs(factors) + 1) # n! = gamma(n + 1) - try: - ev = np.exp(np.multiply(signs, gln)) - term3 = np.multiply.reduce(ev) - except: - print gln - print factors - print signs - raise + term2 = log_Nnij[nij] - log_ab_outer[i][j] + gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] + - gln_N - gln_nij[nij] - gammaln(a[i] - nij) - gammaln(b[j] - nij) + - gammaln(N - a[i] - b[j] + nij)) + term3 = np.exp(gln) # Add the product of all terms - emi += (term1 * term2 * term3) + emi += (term1[nij] * term2 * term3) return emi From f4073722f13a8fc730f5659e9729bdd6f80eee56 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Sun, 6 Nov 2011 21:28:14 +1100 Subject: [PATCH 23/30] Moved start and finish outside of loop --- sklearn/metrics/cluster/supervised.py | 25 ++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 9eefbe6dec1f7..4fe18af4767c2 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -670,22 +670,25 @@ def expected_mutual_information(contingency, n_samples): # term2 uses N * nij log_Nnij = np.log(N * nijs) # term3 uses these factors - gln_a = gammaln(a) - gln_b = gammaln(b) - gln_Na = gammaln(N-a) - gln_Nb = gammaln(N-b) - gln_N = gammaln(N) - gln_nij = gammaln(nijs) + gln_a = gammaln(a+1) + gln_b = gammaln(b+1) + gln_Na = gammaln(N-a+1) + gln_Nb = gammaln(N-b+1) + gln_N = gammaln(N+1) + gln_nij = gammaln(nijs+1) + # start and end values for nij terms + start = np.array([[v - N + w for w in b] for v in a], dtype='int') + start = np.maximum(start, 1) + end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1 emi = 0 for i in range(R): for j in range(C): - start = int(max(a[i] + b[j] - N, 1)) - end = int(min(a[i], b[j]) + 1) - for nij in range(start, end): + for nij in range(start[i][j], end[i][j]): term2 = log_Nnij[nij] - log_ab_outer[i][j] gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] - - gln_N - gln_nij[nij] - gammaln(a[i] - nij+1) - - gammaln(b[j] - nij+1) - - gammaln(N - a[i] - b[j] + nij+1)) term3 = np.exp(gln) # Add the product of all terms emi += (term1[nij] * term2 * term3) From 88adf3b9ab471b3b51c2aa59e8ca5ac4fc74d442 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Mon, 7 Nov 2011 21:10:53 +1100 Subject: [PATCH 24/30] comments, pep8 and pyflakes --- .../plot_adjusted_for_chance_measures.py | 7 ---- sklearn/metrics/cluster/supervised.py | 38 ++++++++++--------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index c579c7cd78113..ab4d644bc590c 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -91,9 +91,6 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, pl.legend(plots, names) pl.ylim(ymin=-0.05, ymax=1.05) pl.show() -# The new version of matplotlib on my computer has a bug where show does nothing -# Sorry if I forget to take this out! -pl.savefig("adj1.png") @@ -126,7 +123,3 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, pl.ylim(ymin=-0.05, ymax=1.05) pl.legend(plots, names) pl.show() - -# The new version of matplotlib on my computer has a bug where show does nothing -# Sorry if I forget to take this out!
-pl.savefig("adj2.png") diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 4fe18af4767c2..03a9c0eec9498 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -8,7 +8,7 @@ # License: BSD Style. from math import log -from scipy.misc import comb, factorial +from scipy.misc import comb from scipy.special import gammaln import numpy as np @@ -654,43 +654,47 @@ def _accumulate_factorials(all_factors): def expected_mutual_information(contingency, n_samples): """Calculate the expected mutual information for two labelings.""" - np.seterr(all='raise') R, C = contingency.shape N = float(n_samples) a = np.sum(contingency, axis=1, dtype='int') b = np.sum(contingency, axis=0, dtype='int') + # There are three major terms to the EMI equation, which are multiplied to + # and then summed over varying nij values. # While nijs[0] will never be used, having it simplifies the indexing. nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float') - nijs[0] = 1 # stops divide by zero problems - # Compute values that are used multiple times. + nijs[0] = 1 # Stops divide by zero warnings. As its not used, no issue. # term1 is nij / N term1 = nijs / N + # term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b) # term2 uses the outer product log_ab_outer = np.log(np.outer(a, b)) # term2 uses N * nij log_Nnij = np.log(N * nijs) - # term3 uses these factors - gln_a = gammaln(a+1) - gln_b = gammaln(b+1) - gln_Na = gammaln(N-a+1) - gln_Nb = gammaln(N-b+1) - gln_N = gammaln(N+1) - gln_nij = gammaln(nijs+1) - # start and end values for nij terms + # term3 is large, and involved many factorials. Calculate these in log + # space to stop overflows. + gln_a = gammaln(a + 1) + gln_b = gammaln(b + 1) + gln_Na = gammaln(N - a + 1) + gln_Nb = gammaln(N - b + 1) + gln_N = gammaln(N + 1) + gln_nij = gammaln(nijs + 1) + # start and end values for nij terms for each summation. start = np.array([[v - N + w for w in b] for v in a], dtype='int') start = np.maximum(start, 1) end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1 + # emi itself is a summation over the various values. emi = 0 for i in range(R): for j in range(C): for nij in range(start[i][j], end[i][j]): - term2 = log_Nnij[nij] - log_ab_outer[i][j] + term2 = log_Nnij[nij] - log_ab_outer[i][j] + # Numerators are positive, denominators are negative. gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] - - gln_N - gln_nij[nij] - gammaln(a[i] - nij+1) - - gammaln(b[j] - nij+1) - - gammaln(N - a[i] - b[j] + nij+1)) + - gln_N - gln_nij[nij] - gammaln(a[i] - nij + 1) + - gammaln(b[j] - nij + 1) + - gammaln(N - a[i] - b[j] + nij + 1)) term3 = np.exp(gln) - # Add the product of all terms + # Add the product of all terms. 
emi += (term1[nij] * term2 * term3) return emi From 2fa105292340a33a7481f959a760e2d5dcb028f3 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Mon, 7 Nov 2011 22:23:15 +1100 Subject: [PATCH 25/30] ami_score -> adjusted_mutual_info_score Warning about slowness of function --- sklearn/metrics/__init__.py | 2 +- sklearn/metrics/cluster/__init__.py | 5 ++- sklearn/metrics/cluster/supervised.py | 40 ++++--------------- .../metrics/cluster/tests/test_supervised.py | 12 +++--- 4 files changed, 18 insertions(+), 41 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index cfd56dc57ba19..6ac85eadf8d95 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -17,5 +17,5 @@ from .cluster import v_measure_score from .cluster import silhouette_score from .cluster import mutual_information_score -from .cluster import ami_score +from .cluster import adjusted_mutual_info_score from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 63c4129ac2592..50be53458a648 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -8,6 +8,7 @@ from supervised import (homogeneity_completeness_v_measure, homogeneity_score, completeness_score, v_measure_score, adjusted_rand_score, - ami_score, expected_mutual_information, - mutual_information_score, contingency_matrix) + adjusted_mutual_info_score, + expected_mutual_information, mutual_information_score, + contingency_matrix) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 03a9c0eec9498..81c49257a716f 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -546,7 +546,7 @@ def mutual_information_score(labels_true, labels_pred, contingency=None): return mi.sum() -def ami_score(labels_true, labels_pred): +def adjusted_mutual_info_score(labels_true, labels_pred): """Adjusted Mutual Information between two clusterings Adjusted Mutual Information (AMI) is an adjustment of the Mutual @@ -566,6 +566,9 @@ def ami_score(labels_true, labels_pred): measure the agreement of two independent label assignments strategies on the same dataset when the real ground truth is not known. + Be mindful that this function is an order of magnitude slower than other + metrics, such as the Adjusted Rand Index. + Parameters ---------- labels_true : int array, shape = [n_samples] @@ -590,16 +593,16 @@ def ami_score(labels_true, labels_pred): Perfect labelings are both homogeneous and complete, hence have score 1.0:: - >>> from sklearn.metrics.cluster import ami_score - >>> ami_score([0, 0, 1, 1], [0, 0, 1, 1]) + >>> from sklearn.metrics.cluster import adjusted_mutual_info_score + >>> adjusted_mutual_info_score([0, 0, 1, 1], [0, 0, 1, 1]) 1.0 - >>> ami_score([0, 0, 1, 1], [1, 1, 0, 0]) + >>> adjusted_mutual_info_score([0, 0, 1, 1], [1, 1, 0, 0]) 1.0 If class members are completely split across different clusters, the assignment is totally incomplete, hence the AMI is null:: - >>> ami_score([0, 0, 0, 0], [0, 1, 2, 3]) + >>> adjusted_mutual_info_score([0, 0, 0, 0], [0, 1, 2, 3]) 0.0 """ @@ -625,33 +628,6 @@ def ami_score(labels_true, labels_pred): return ami -def _accumulate_factorials(all_factors): - """Calculates the product of the given factorials.
- - This function solves equations of the form: - - \frac{\prod_i n_i!}{\prod_j d_j!} - - Parameters - ---------- - all_factors: list of signed integers - Integers are positive if they are a numerator in the equation, - and negative if they are in the denominator. - - Returns - ------- - t: float - The product of the factorials of the values that are positive, - divided by the product of the factorials of the values that are - negative. - """ - a = np.zeros((all_factors.max() + 1)) - b = np.arange(all_factors.max() + 1, dtype='float') - for factor in all_factors: - a[:(abs(factor) + 1)] += np.sign(factor) - return np.exp(np.sum(np.dot(a[2:], np.log(b[2:])))) - - def expected_mutual_information(contingency, n_samples): """Calculate the expected mutual information for two labelings.""" R, C = contingency.shape diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index b47badacb7969..5ca1cc34395e7 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -5,7 +5,7 @@ from sklearn.metrics.cluster import completeness_score from sklearn.metrics.cluster import v_measure_score from sklearn.metrics.cluster import homogeneity_completeness_v_measure -from sklearn.metrics.cluster import ami_score +from sklearn.metrics.cluster import adjusted_mutual_info_score from sklearn.metrics.cluster import mutual_information_score from sklearn.metrics.cluster import expected_mutual_information from sklearn.metrics.cluster import contingency_matrix @@ -20,7 +20,7 @@ homogeneity_score, completeness_score, v_measure_score, - ami_score, + adjusted_mutual_info_score, ] @@ -135,7 +135,7 @@ def test_adjustment_for_chance(): assert_array_almost_equal(max_abs_scores, [0.02, 0.03, 0.03, 0.02], 2) -def test_ami_score(): +def test_adjusted_mutual_info_score(): """Compute the Adjusted Mutual Information and test against known values""" labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) # Mutual information @@ -148,13 +148,13 @@ def test_ami_score(): emi = expected_mutual_information(C, n_samples) assert_almost_equal(emi, 0.15042, 5) # Adjusted mutual information - ami = ami_score(labels_a, labels_b) + ami = adjusted_mutual_info_score(labels_a, labels_b) assert_almost_equal(ami, 0.27502, 5) - ami = ami_score([1, 1, 2, 2], [2, 2, 3, 3]) + ami = adjusted_mutual_info_score([1, 1, 2, 2], [2, 2, 3, 3]) assert_equal(ami, 1.0) # Test with a very large array a110 = np.array([list(labels_a) * 110]).flatten() b110 = np.array([list(labels_b) * 110]).flatten() - ami = ami_score(a110, b110) + ami = adjusted_mutual_info_score(a110, b110) # This is not accurate to more than 2 places assert_almost_equal(ami, 0.37, 2) From e51e4d679bedbc8328e2ea86a3c6c72b83318874 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Mon, 7 Nov 2011 22:46:11 +1100 Subject: [PATCH 26/30] ami_score -> adjusted_mutual_info_score This time, in examples and docs!
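Only the name changes; behaviour is identical. For instance, with values already asserted in test_supervised.py:

    >>> from sklearn.metrics import adjusted_mutual_info_score
    >>> adjusted_mutual_info_score([1, 1, 2, 2], [2, 2, 3, 3])
    1.0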
--- doc/modules/clustering.rst | 14 +++++++------- .../cluster/plot_adjusted_for_chance_measures.py | 2 +- examples/cluster/plot_affinity_propagation.py | 2 +- examples/cluster/plot_dbscan.py | 2 +- examples/cluster/plot_kmeans_digits.py | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst index e1e6fc4e4102a..9119385011a03 100644 --- a/doc/modules/clustering.rst +++ b/doc/modules/clustering.rst @@ -488,34 +488,34 @@ measures the **agreement** of the two assignments, ignoring permutations >>> labels_true = [0, 0, 0, 1, 1, 1] >>> labels_pred = [0, 0, 1, 1, 2, 2] - >>> metrics.ami_score(labels_true, labels_pred) # doctest: +ELLIPSIS + >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred) # doctest: +ELLIPSIS 0.24... One can permute 0 and 1 in the predicted labels and rename `2` by `3` and get the same score:: >>> labels_pred = [1, 1, 0, 0, 3, 3] - >>> metrics.ami_score(labels_true, labels_pred) # doctest: +ELLIPSIS + >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred) # doctest: +ELLIPSIS 0.24... -Furthermore, :func:`ami_score` is **symmetric**: swapping the argument -does not change the score. It can thus be used as a **consensus +Furthermore, :func:`adjusted_mutual_info_score` is **symmetric**: swapping the +arguments does not change the score. It can thus be used as a **consensus measure**:: - >>> metrics.ami_score(labels_pred, labels_true) # doctest: +ELLIPSIS + >>> metrics.adjusted_mutual_info_score(labels_pred, labels_true) # doctest: +ELLIPSIS 0.24... Perfect labeling is scored 1.0:: >>> labels_pred = labels_true[:] - >>> metrics.ami_score(labels_true, labels_pred) + >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred) 1.0 Bad (e.g. independent labelings) have scores of zero:: >>> labels_true = [0, 1, 2, 0, 3, 4, 5, 1] >>> labels_pred = [1, 1, 0, 0, 2, 2, 2, 2] - >>> metrics.ami_score(labels_true, labels_pred) # doctest: +ELLIPSIS + >>> metrics.adjusted_mutual_info_score(labels_true, labels_pred) # doctest: +ELLIPSIS 0.0...
diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index ab4d644bc590c..9cdf9573c4ca1 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -59,7 +59,7 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, score_funcs = [ metrics.adjusted_rand_score, metrics.v_measure_score, - metrics.ami_score, + metrics.adjusted_mutual_info_score, metrics.mutual_information_score, ] diff --git a/examples/cluster/plot_affinity_propagation.py b/examples/cluster/plot_affinity_propagation.py index 9aedba4e3d970..d3d5ff6eb33d0 100644 --- a/examples/cluster/plot_affinity_propagation.py +++ b/examples/cluster/plot_affinity_propagation.py @@ -41,7 +41,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels_true, labels) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels_true, labels) + metrics.adjusted_mutual_info_score(labels_true, labels) D = (S / np.min(S)) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(D, labels, metric='precomputed')) diff --git a/examples/cluster/plot_dbscan.py b/examples/cluster/plot_dbscan.py index 8b12f0ccd6f11..e6a686e064c47 100644 --- a/examples/cluster/plot_dbscan.py +++ b/examples/cluster/plot_dbscan.py @@ -42,7 +42,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels_true, labels) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels_true, labels) + metrics.adjusted_mutual_info_score(labels_true, labels) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(D, labels, metric='precomputed')) diff --git a/examples/cluster/plot_kmeans_digits.py b/examples/cluster/plot_kmeans_digits.py index e0e0e02b165cc..811bfdd712b8e 100644 --- a/examples/cluster/plot_kmeans_digits.py +++ b/examples/cluster/plot_kmeans_digits.py @@ -49,7 +49,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels, km.labels_) + metrics.adjusted_mutual_info_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) @@ -66,7 +66,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels, km.labels_) + metrics.adjusted_mutual_info_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) @@ -86,7 +86,7 @@ print "Adjusted Rand Index: %0.3f" % \ metrics.adjusted_rand_score(labels, km.labels_) print "Adjusted Mutual Information: %0.3f" % \ - metrics.ami_score(labels, km.labels_) + metrics.adjusted_mutual_info_score(labels, km.labels_) print ("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, km.labels_, metric='euclidean', sample_size=sample_size)) From 66f7a0bca706142ae8d5b97c5d15be9ac61db807 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Tue, 8 Nov 2011 13:34:23 +1100 Subject: [PATCH 27/30] "What's new?" AMI! 
--- doc/whats_new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index dfc3deec0951e..deb4e7d042652 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -18,6 +18,9 @@ - Silhouette Coefficient cluster analysis evaluation metric added as ``sklearn.metrics.silhouette_score`` by Robert Layton. + - Adjusted Mutual Information metric added as + ``sklearn.metrics.adjusted_mutual_info_score`` by Robert Layton. + API changes summary ------------------- From f26ed764b14854641b2bc905d2098580ff66e951 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Thu, 10 Nov 2011 16:21:17 +1100 Subject: [PATCH 28/30] mutual_information_score -> mutual_info_score --- sklearn/metrics/__init__.py | 2 +- sklearn/metrics/cluster/__init__.py | 2 +- sklearn/metrics/cluster/supervised.py | 6 +++--- sklearn/metrics/cluster/tests/test_supervised.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 6ac85eadf8d95..c8123b2343573 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -16,6 +16,6 @@ from .cluster import completeness_score from .cluster import v_measure_score from .cluster import silhouette_score -from .cluster import mutual_information_score +from .cluster import mutual_info_score from .cluster import adjusted_mutual_info_score from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 50be53458a648..81d092039a49a 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -9,6 +9,6 @@ homogeneity_score, completeness_score, v_measure_score, adjusted_rand_score, adjusted_mutual_info_score, - expected_mutual_information, mutual_information_score, + expected_mutual_information, mutual_info_score, contingency_matrix) from unsupervised import silhouette_score, silhouette_samples diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 81c49257a716f..874c65ba0fac4 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -489,7 +489,7 @@ def v_measure_score(labels_true, labels_pred): return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] -def mutual_information_score(labels_true, labels_pred, contingency=None): +def mutual_info_score(labels_true, labels_pred, contingency=None): """Mutual Information between two clusterings The Mutual Information is a measure of the similarity between two labels @@ -618,8 +618,8 @@ def adjusted_mutual_info_score(labels_true, labels_pred): contingency = contingency_matrix(labels_true, labels_pred) contingency = np.array(contingency, dtype='float') # Calculate the MI for the two clusterings - mi = mutual_information_score(labels_true, labels_pred, - contingency=contingency) + mi = mutual_info_score(labels_true, labels_pred, + contingency=contingency) # Calculate the expected value for the mutual information emi = expected_mutual_information(contingency, n_samples) # Calculate entropy for each labeling diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index 5ca1cc34395e7..98399d9dc2288 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -6,7 +6,7 @@ from sklearn.metrics.cluster import v_measure_score from sklearn.metrics.cluster import
homogeneity_completeness_v_measure from sklearn.metrics.cluster import adjusted_mutual_info_score -from sklearn.metrics.cluster import mutual_information_score +from sklearn.metrics.cluster import mutual_info_score from sklearn.metrics.cluster import expected_mutual_information from sklearn.metrics.cluster import contingency_matrix @@ -140,7 +140,7 @@ def test_adjusted_mutual_info_score(): labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) # Mutual information - mi = mutual_information_score(labels_a, labels_b) + mi = mutual_info_score(labels_a, labels_b) assert_almost_equal(mi, 0.41022, 5) # Expected mutual information C = contingency_matrix(labels_a, labels_b) From 17ee6c41e845c1d8ab042a19fb447d455a90c640 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Thu, 10 Nov 2011 16:23:38 +1100 Subject: [PATCH 29/30] and in plot_adjusted example (mutual_info_score) --- examples/cluster/plot_adjusted_for_chance_measures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cluster/plot_adjusted_for_chance_measures.py b/examples/cluster/plot_adjusted_for_chance_measures.py index 9cdf9573c4ca1..cec0c1e739083 100644 --- a/examples/cluster/plot_adjusted_for_chance_measures.py +++ b/examples/cluster/plot_adjusted_for_chance_measures.py @@ -60,7 +60,7 @@ def uniform_labelings_scores(score_func, n_samples, n_clusters_range, metrics.adjusted_rand_score, metrics.v_measure_score, metrics.adjusted_mutual_info_score, - metrics.mutual_information_score, + metrics.mutual_info_score, ] # 2 independent random clusterings with equal cluster number From 118e8bde65ea9a24d648777401bd38af6ecd28d3 Mon Sep 17 00:00:00 2001 From: Robert Layton Date: Thu, 10 Nov 2011 22:00:18 +1100 Subject: [PATCH 30/30] cosmit --- doc/whats_new.rst | 4 ++-- sklearn/metrics/cluster/supervised.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 88adde130b04d..4dbef627a3e57 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -19,7 +19,7 @@ Changelog - Faster tests by `Fabian Pedregosa`_. - Silhouette Coefficient cluster analysis evaluation metric added as - ``sklearn.metrics.silhouette_score`` by `Robert Layton`. + ``sklearn.metrics.silhouette_score`` by `Robert Layton`_. - Fixed a bug in `KMeans` in the handling of the `n_init` parameter: the clustering algorithm used to be run `n_init` times but the last @@ -29,7 +29,7 @@ Changelog predict methods. - Adjusted Mutual Information metric added as - ``sklearn.metrics.adjusted_mutual_info_score`` by `Robert Layton`. + ``sklearn.metrics.adjusted_mutual_info_score`` by `Robert Layton`_. API changes summary diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 874c65ba0fac4..f253a9a5f9085 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -157,7 +157,7 @@ def adjusted_rand_score(labels_true, labels_pred): See also -------- - - ami_score: Adjusted Mutual Information + - adjusted_mutual_info_score: Adjusted Mutual Information """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -529,7 +529,7 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): See also -------- - - ami_score: Adjusted Mutual Information + - adjusted_mutual_info_score: Adjusted Mutual Information """ if contingency is None: labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
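The series leaves the supervised cluster metrics exposing contingency_matrix, mutual_info_score, expected_mutual_information and adjusted_mutual_info_score. A short end-to-end sketch of the final API, reusing the labelings and the approximate values asserted in test_supervised.py:

    import numpy as np
    from sklearn import metrics
    from sklearn.metrics.cluster import (contingency_matrix,
                                         expected_mutual_information)

    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])

    mi = metrics.mutual_info_score(labels_a, labels_b)   # ~0.41022, in nats
    C = contingency_matrix(labels_a, labels_b)
    emi = expected_mutual_information(C, np.sum(C))      # ~0.15042
    # AMI rescales MI by its expected value under random labelings:
    # (mi - emi) / (max of the two label entropies - emi)
    ami = metrics.adjusted_mutual_info_score(labels_a, labels_b)  # ~0.27502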