diff --git a/scikits/learn/utils/hungarian.py b/scikits/learn/utils/hungarian.py index 3e5affdcf82cf..c225e40cc6513 100644 --- a/scikits/learn/utils/hungarian.py +++ b/scikits/learn/utils/hungarian.py @@ -56,12 +56,13 @@ def compute(self, cost_matrix): """ self.C = cost_matrix.copy() self.n = n = self.C.shape[0] + self.m = m = self.C.shape[1] self.row_uncovered = np.ones(n, dtype=np.bool) - self.col_uncovered = np.ones(n, dtype=np.bool) + self.col_uncovered = np.ones(m, dtype=np.bool) self.Z0_r = 0 self.Z0_c = 0 - self.path = np.zeros((2*n, 2), dtype=int) - self.marked = np.zeros((n, n), dtype=int) + self.path = np.zeros((n+m, 2), dtype=int) + self.marked = np.zeros((n, m), dtype=int) done = False step = 1 @@ -111,7 +112,7 @@ def _step3(self): marked = (self.marked == 1) self.col_uncovered[np.any(marked, axis=0)] = False - if marked.sum() >= self.n: + if marked.sum() >= min(self.m, self.n) : return 7 # done else: return 4 @@ -129,11 +130,10 @@ def _step4(self): covered_C = C*self.row_uncovered[:, np.newaxis] covered_C *= self.col_uncovered.astype(np.int) n = self.n + m = self.m while True: # Find an uncovered zero - raveled_idx = np.argmax(covered_C) - col = raveled_idx % n - row = raveled_idx // n + row, col = np.unravel_index(np.argmax(covered_C), (n, m)) if covered_C[row, col] == 0: return 6 else: @@ -212,10 +212,11 @@ def _step6(self): lines. """ # the smallest uncovered value in the matrix - minval = np.min(self.C[self.row_uncovered], axis=0) - minval = np.min(minval[self.col_uncovered]) - self.C[np.logical_not(self.row_uncovered)] += minval - self.C[:, self.col_uncovered] -= minval + if np.any(self.row_uncovered) and np.any(self.col_uncovered): + minval = np.min(self.C[self.row_uncovered], axis=0) + minval = np.min(minval[self.col_uncovered]) + self.C[np.logical_not(self.row_uncovered)] += minval + self.C[:, self.col_uncovered] -= minval return 4 def _find_prime_in_row(self, row): diff --git a/scikits/learn/utils/tests/test_hungarian.py b/scikits/learn/utils/tests/test_hungarian.py index 7b6c8230b8802..f3cb5b4de5e07 100644 --- a/scikits/learn/utils/tests/test_hungarian.py +++ b/scikits/learn/utils/tests/test_hungarian.py @@ -15,11 +15,11 @@ def test_hungarian(): ), ## Rectangular variant - #([[400, 150, 400, 1], - # [400, 450, 600, 2], - # [300, 225, 300, 3]], - # 452 # expected cost - #), + ([[400, 150, 400, 1], + [400, 450, 600, 2], + [300, 225, 300, 3]], + 452 # expected cost + ), # Square ([[10, 10, 8], @@ -29,11 +29,11 @@ def test_hungarian(): ), ## Rectangular variant - #([[10, 10, 8, 11], - # [ 9, 8, 1, 1], - # [ 9, 7, 4, 10]], - # 15 - #), + ([[10, 10, 8, 11], + [ 9, 8, 1, 1], + [ 9, 7, 4, 10]], + 15 + ), ] m = _Hungarian() @@ -54,3 +54,10 @@ def test_find_permutation(): np.testing.assert_array_equal(find_permutation(B, A), np.arange(10)[::-1]) + +if __name__ == '__main__' : + print "find_permutations test..." + test_find_permutation() + print "Hungarian test..." + test_hungarian() +