diff --git a/MANIFEST.in b/MANIFEST.in index a94d4668afcce..e924b10bd379d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,5 +4,4 @@ include scikits/__init__.py recursive-include doc * recursive-include examples * recursive-include scikits *.c *.h *.pyx -recursive-include scikits/learn/datasets *.csv *.csv.gz *.TXT *.rst -recursive-include scikits/learn/datasets/images *.jpg *.txt +recursive-include scikits/learn/datasets *.csv *.csv.gz *.TXT *.rst *.jpg *.txt diff --git a/examples/cluster/vq_china.py b/examples/cluster/vq_china.py index 02a6bfe3fc65d..c0b048af1f932 100644 --- a/examples/cluster/vq_china.py +++ b/examples/cluster/vq_china.py @@ -8,25 +8,22 @@ number of colors required to show the image. """ print __doc__ - +import os import numpy as np import pylab as pl from scikits.learn.cluster import KMeans from scikits.learn.datasets import load_sample_images -# Try to import Image and imresize from PIL. We do this here to prevent -# this module from depending on PIL. -try: - try: - from scipy.misc import Image - except ImportError: - from scipy.misc.pilutil import Image -except ImportError: - raise ImportError("The Python Imaging Library (PIL)" - "is required to load data from jpeg files") # Get all sample images and obtain just china.jpg +sample_image_name = "china.jpg" sample_images = load_sample_images() -index = sample_images.filenames.index("china.jpg") +index = None +for i, filename in enumerate(sample_images.filenames): + if filename.endswith(sample_image_name): + index = i + break +if index is None: + raise AttributeError("Cannot find sample image: %s" % sample_image_name) image_data = sample_images.images[index] # Load Image and transform to a 2D numpy array. @@ -36,13 +33,13 @@ # Take a sample of the data. sample_indices = range(len(image_array)) np.random.shuffle(sample_indices) -sample_indices = sample_indices[:int(len(image_array) * 0.5)] +sample_indices = sample_indices[:int(len(image_array) * 0.2)] sample_data = image_array[sample_indices] # Perform Vector Quantisation with 256 clusters. k = 256 kmeans = KMeans(k=k) -kmeans.fit(image_array) +kmeans.fit(sample_data) # Get labels for all points labels = kmeans.predict(image_array) # Save the reduced dataset. Only the centroids and labels need to be saved. @@ -56,8 +53,9 @@ def recreate_image(centroids, labels, w, h): for i in range(w): for j in range(h): image[i][j] = centroids[labels[label_num]] - print labels[label_num], label_num label_num += 1 + print np.histogram(labels) + print set(labels) return image # Display all results, alongside original image @@ -67,5 +65,5 @@ def recreate_image(centroids, labels, w, h): centroids, labels = reduced_image im = pl.imshow(recreate_image(centroids, labels, w, h)) -show() +pl.show() diff --git a/scikits/learn/datasets/base.py b/scikits/learn/datasets/base.py index b5e552ad58abb..6077374cc0917 100644 --- a/scikits/learn/datasets/base.py +++ b/scikits/learn/datasets/base.py @@ -7,6 +7,7 @@ # 2010 Olivier Grisel # License: Simplified BSD +import os import csv import shutil import textwrap @@ -372,9 +373,20 @@ def load_sample_images(): >>> # pl.matshow(images.images[0]) # Visualize the first image >>> # pl.show() """ + # Try to import Image and imresize from PIL. We do this here to prevent + # this module from depending on PIL. + try: + try: + from scipy.misc import Image + except ImportError: + from scipy.misc.pilutil import Image + except ImportError: + raise ImportError("The Python Imaging Library (PIL)" + "is required to load data from jpeg files") module_path = join(dirname(__file__), "images") descr = open(join(module_path, 'README.txt')).read() - filenames = [filename for filename in os.listdir(module_path) + filenames = [join(module_path, filename) + for filename in os.listdir(module_path) if filename.endswith(".jpg")] # Load image data for each image in the source folder. images = [np.asarray(Image.open(filename)) diff --git a/scikits/learn/datasets/setup.py b/scikits/learn/datasets/setup.py index bdf08f0d6167f..8673f9862e620 100644 --- a/scikits/learn/datasets/setup.py +++ b/scikits/learn/datasets/setup.py @@ -7,6 +7,7 @@ def configuration(parent_package='', top_path=None): config = Configuration('datasets', parent_package, top_path) config.add_data_dir('data') config.add_data_dir('descr') + config.add_data_dir('images') config.add_extension('_svmlight_format',