Skip to content

Commit

Permalink
Merge pull request GaelVaroquaux#2 from WeatherGod/square2rect
Browse files Browse the repository at this point in the history
Square2rect
  • Loading branch information
GaelVaroquaux committed Jun 6, 2011
2 parents 562d5b3 + 38f9a46 commit 372b125
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
23 changes: 12 additions & 11 deletions scikits/learn/utils/hungarian.py
Expand Up @@ -56,12 +56,13 @@ def compute(self, cost_matrix):
"""
self.C = cost_matrix.copy()
self.n = n = self.C.shape[0]
self.m = m = self.C.shape[1]
self.row_uncovered = np.ones(n, dtype=np.bool)
self.col_uncovered = np.ones(n, dtype=np.bool)
self.col_uncovered = np.ones(m, dtype=np.bool)
self.Z0_r = 0
self.Z0_c = 0
self.path = np.zeros((2*n, 2), dtype=int)
self.marked = np.zeros((n, n), dtype=int)
self.path = np.zeros((n+m, 2), dtype=int)
self.marked = np.zeros((n, m), dtype=int)

done = False
step = 1
Expand Down Expand Up @@ -111,7 +112,7 @@ def _step3(self):
marked = (self.marked == 1)
self.col_uncovered[np.any(marked, axis=0)] = False

if marked.sum() >= self.n:
if marked.sum() >= min(self.m, self.n) :
return 7 # done
else:
return 4
Expand All @@ -129,11 +130,10 @@ def _step4(self):
covered_C = C*self.row_uncovered[:, np.newaxis]
covered_C *= self.col_uncovered.astype(np.int)
n = self.n
m = self.m
while True:
# Find an uncovered zero
raveled_idx = np.argmax(covered_C)
col = raveled_idx % n
row = raveled_idx // n
row, col = np.unravel_index(np.argmax(covered_C), (n, m))
if covered_C[row, col] == 0:
return 6
else:
Expand Down Expand Up @@ -212,10 +212,11 @@ def _step6(self):
lines.
"""
# the smallest uncovered value in the matrix
minval = np.min(self.C[self.row_uncovered], axis=0)
minval = np.min(minval[self.col_uncovered])
self.C[np.logical_not(self.row_uncovered)] += minval
self.C[:, self.col_uncovered] -= minval
if np.any(self.row_uncovered) and np.any(self.col_uncovered):
minval = np.min(self.C[self.row_uncovered], axis=0)
minval = np.min(minval[self.col_uncovered])
self.C[np.logical_not(self.row_uncovered)] += minval
self.C[:, self.col_uncovered] -= minval
return 4

def _find_prime_in_row(self, row):
Expand Down
27 changes: 17 additions & 10 deletions scikits/learn/utils/tests/test_hungarian.py
Expand Up @@ -15,11 +15,11 @@ def test_hungarian():
),

## Rectangular variant
#([[400, 150, 400, 1],
# [400, 450, 600, 2],
# [300, 225, 300, 3]],
# 452 # expected cost
#),
([[400, 150, 400, 1],
[400, 450, 600, 2],
[300, 225, 300, 3]],
452 # expected cost
),

# Square
([[10, 10, 8],
Expand All @@ -29,11 +29,11 @@ def test_hungarian():
),

## Rectangular variant
#([[10, 10, 8, 11],
# [ 9, 8, 1, 1],
# [ 9, 7, 4, 10]],
# 15
#),
([[10, 10, 8, 11],
[ 9, 8, 1, 1],
[ 9, 7, 4, 10]],
15
),
]

m = _Hungarian()
Expand All @@ -54,3 +54,10 @@ def test_find_permutation():
np.testing.assert_array_equal(find_permutation(B, A),
np.arange(10)[::-1])


if __name__ == '__main__' :
print "find_permutations test..."
test_find_permutation()
print "Hungarian test..."
test_hungarian()

0 comments on commit 372b125

Please sign in to comment.