Skip to content

Commit

Permalink
Dataset loader moved to datasets.base, but not being installed
Browse files Browse the repository at this point in the history
  • Loading branch information
X006 committed Aug 22, 2011
1 parent 1bc54d4 commit 8811af0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 11 deletions.
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include scikits/__init__.py
recursive-include doc *
recursive-include examples *
recursive-include scikits *.c *.h *.pyx
recursive-include scikits/learn/datasets *.csv *.csv.gz *.TXT *.rst
recursive-include scikits/learn/datasets *.csv *.csv.gz *.TXT *.rst
recursive-include scikits/learn/datasets/images *.jpg *.txt
31 changes: 21 additions & 10 deletions examples/cluster/vq_china.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@
print __doc__

import numpy as np
from PIL import Image
import pylab as pl
from scikits.learn.cluster import KMeans
from pylab import figure, axes, imshow, show
from scikits.learn.datasets import load_sample_images
# Try to import Image and imresize from PIL. We do this here to prevent
# this module from depending on PIL.
try:
try:
from scipy.misc import Image
except ImportError:
from scipy.misc.pilutil import Image
except ImportError:
raise ImportError("The Python Imaging Library (PIL)"
"is required to load data from jpeg files")

# Load Image
filename = "/home/bob/Code/scikit-learn/scikits/learn/datasets/images/china.jpg"
# Get all sample images and obtain just china.jpg
sample_images = load_sample_images()
index = sample_images.filenames.index("china.jpg")
image_data = sample_images.images[index]

# Transform to numpy array
image_data = np.asarray(Image.open(filename))
# Load Image and transform to a 2D numpy array.
w, h, d = original_shape = tuple(image_data.shape)
image_array = np.reshape(image_data, (w * h, 3))

Expand All @@ -28,7 +39,7 @@
sample_indices = sample_indices[:int(len(image_array) * 0.5)]
sample_data = image_array[sample_indices]

# Perform Vector Quantisation with 256 clusters
# Perform Vector Quantisation with 256 clusters.
k = 256
kmeans = KMeans(k=k)
kmeans.fit(image_array)
Expand All @@ -50,11 +61,11 @@ def recreate_image(centroids, labels, w, h):
return image

# Display all results, alongside original image
figure()
ax = axes([0,0,1,1], frameon=False)
pl.figure()
ax = pl.axes([0,0,1,1], frameon=False)
ax.set_axis_off()
centroids, labels = reduced_image
im = imshow(recreate_image(centroids, labels, w, h))
im = pl.imshow(recreate_image(centroids, labels, w, h))

show()

1 change: 1 addition & 0 deletions scikits/learn/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .base import load_boston
from .base import get_data_home
from .base import clear_data_home
from .base import load_sample_images
from .mlcomp import load_mlcomp
from .lfw import load_lfw_pairs
from .lfw import load_lfw_people
Expand Down
36 changes: 36 additions & 0 deletions scikits/learn/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,39 @@ def load_boston():
feature_names=feature_names,
DESCR=fdescr.read())


def load_sample_images():
""" Load sample images for image manipulation.
Return
------
data : Bunch
Dictionary-like object, the interesting attributes are:
'data', the data to learn, `images`, the images corresponding
to each sample, 'target', the classification labels for each
sample, 'target_names', the meaning of the labels, and 'DESCR',
the full description of the dataset.
Examples
--------
To load the data and visualize the images::
>>> from scikits.learn.datasets import load_sample_images
>>> images = load_sample_images()
>>> # import pylab as pl
>>> # pl.gray()
>>> # pl.matshow(images.images[0]) # Visualize the first image
>>> # pl.show()
"""
module_path = join(dirname(__file__), "images")
descr = open(join(module_path, 'README.txt')).read()
filenames = [filename for filename in os.listdir(module_path)
if filename.endswith(".jpg")]
# Load image data for each image in the source folder.
images = [np.asarray(Image.open(filename))
for filename in filenames]
return Bunch(images=images,
filenames=filenames,
DESCR=descr)

0 comments on commit 8811af0

Please sign in to comment.