-
-
Notifications
You must be signed in to change notification settings - Fork 25.1k
/
vq_china.py
69 lines (60 loc) · 2.04 KB
/
vq_china.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
"""
==============================================
Vector Quantization of china.jpg using k-means
==============================================
Performs a Vector Quantization of an image, reducing the
number of colors required to show the image.
"""
# Echo the module docstring when the example is run. The parenthesised form
# prints the same text whether `print` is a statement or a function.
print(__doc__)
import os
import numpy as np
import pylab as pl
from scikits.learn.cluster import KMeans
from scikits.learn.datasets import load_sample_images
# Locate china.jpg among the bundled sample images.
sample_image_name = "china.jpg"
sample_images = load_sample_images()
index = None
for i, filename in enumerate(sample_images.filenames):
    if filename.endswith(sample_image_name):
        index = i
        break
if index is None:
    raise AttributeError("Cannot find sample image: %s" % sample_image_name)
image_data = sample_images.images[index]

# Flatten the image into a 2D array with one row per pixel and one column per
# channel. The channel count comes from the image itself rather than being
# assumed to be 3.
# NOTE(review): `w` and `h` simply name the first and second array axes; for
# image arrays the first axis is usually rows (height) -- confirm if the names
# matter elsewhere.
w, h, d = original_shape = tuple(image_data.shape)
image_array = np.reshape(image_data, (w * h, d))

# Fit on a random 20% sample of the pixels. `permutation` returns a freshly
# shuffled index array, so this does not depend on `range` returning a
# mutable list.
n_samples = int(len(image_array) * 0.2)
sample_indices = np.random.permutation(len(image_array))[:n_samples]
sample_data = image_array[sample_indices]

# Perform Vector Quantisation with 256 clusters.
k = 256
kmeans = KMeans(k=k)
kmeans.fit(sample_data)

# Assign every pixel (not just the sampled ones) to a cluster.
labels = kmeans.predict(image_array)

# Save the reduced dataset. Only the centroids and labels need to be saved.
reduced_image = (kmeans.cluster_centers_, labels)
def recreate_image(centroids, labels, w, h):
    """Rebuild the (compressed) image from its codebook and per-pixel labels.

    Parameters
    ----------
    centroids : array-like, shape (n_clusters, d)
        Codebook; row ``i`` is the value written for every pixel labelled
        ``i``.
    labels : array-like of int
        Codebook index of each pixel, in row-major order. Only the first
        ``w * h`` entries are used.
    w, h : int
        Sizes of the first and second axes of the returned image.

    Returns
    -------
    image : ndarray of float64, shape (w, h, d)
        ``image[i, j]`` equals ``centroids[labels[i * h + j]]``.

    Raises
    ------
    IndexError
        If ``centroids`` is empty or fewer than ``w * h`` labels are given.
    """
    codebook = np.asarray(centroids, dtype=np.float64)
    d = codebook.shape[1]
    n_pixels = w * h
    flat_labels = np.asarray(labels, dtype=np.intp)[:n_pixels]
    if flat_labels.shape[0] < n_pixels:
        raise IndexError(
            "expected at least %d labels, got %d"
            % (n_pixels, flat_labels.shape[0]))
    # A single fancy-indexing lookup replaces the per-pixel Python loop.
    image = codebook[flat_labels].reshape(w, h, d)
    # Diagnostic output: label distribution and the set of labels in use.
    print(np.histogram(labels))
    print(set(labels))
    return image
# Display the quantized image in a borderless figure that fills the window.
pl.figure()
ax = pl.axes([0, 0, 1, 1], frameon=False)
ax.set_axis_off()
centroids, labels = reduced_image
quantized = recreate_image(centroids, labels, w, h)
# recreate_image returns a float array, and matplotlib interprets float RGB
# data on a 0-1 scale. When the source image uses an integer dtype, cast the
# result back to that dtype so it is rendered on the same scale as the source.
if np.issubdtype(image_data.dtype, np.integer):
    bounds = np.iinfo(image_data.dtype)
    quantized = np.clip(np.rint(quantized), bounds.min, bounds.max)
    quantized = quantized.astype(image_data.dtype)
im = pl.imshow(quantized)
pl.show()