Skip to content

Commit

Permalink
idf_ setter for TfidfTransformer. Fixes scikit-learn#7102
Browse files Browse the repository at this point in the history
  • Loading branch information
serega committed Apr 1, 2018
1 parent 12cdb83 commit 639a6a0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
10 changes: 10 additions & 0 deletions sklearn/feature_extraction/tests/test_text.py
Expand Up @@ -942,6 +942,16 @@ def test_pickling_transformer():
orig.fit_transform(X).toarray())


def test_transformer_idf_setter():
    """Check that an assigned ``idf_`` is actually used by ``transform``.

    A transformer that was never fitted but had ``idf_`` assigned from a
    fitted one must produce the same tf-idf matrix as the fitted one.
    """
    X = CountVectorizer().fit_transform(JUNK_FOOD_DOCS)
    orig = TfidfTransformer().fit(X)
    copy = TfidfTransformer()
    copy.idf_ = orig.idf_
    # Compare transform() outputs, not fit_transform(): re-fitting ``copy``
    # would recompute the idf weights from X and overwrite the assigned
    # value, so the assertion would hold even with a broken setter.
    assert_array_equal(
        copy.transform(X).toarray(),
        orig.transform(X).toarray())


def test_non_unique_vocab():
vocab = ['a', 'b', 'c', 'a', 'a']
vect = CountVectorizer(vocabulary=vocab)
Expand Down
13 changes: 13 additions & 0 deletions sklearn/feature_extraction/text.py
Expand Up @@ -1062,6 +1062,12 @@ class TfidfTransformer(BaseEstimator, TransformerMixin):
sublinear_tf : boolean, default=False
Apply sublinear tf scaling, i.e. replace tf with 1 + log(tf).
Attributes
----------
idf_ : array, shape = [n_features], or None
The learned idf vector (global term weights)
when ``use_idf`` is set to True, None otherwise.
References
----------
Expand Down Expand Up @@ -1157,6 +1163,13 @@ def idf_(self):
# which means hasattr(self, "idf_") is False
return np.ravel(self._idf_diag.sum(axis=0))

@idf_.setter
def idf_(self, value):
    """Set the idf vector from a one-dimensional array-like.

    Parameters
    ----------
    value : array-like, shape = [n_features]
        Per-feature idf weights. Converted to a float64 array and stored
        as the main diagonal of a square CSR matrix in ``_idf_diag``.

    Raises
    ------
    ValueError
        If ``value`` is not one-dimensional.
    """
    value = np.asarray(value, dtype=np.float64)
    # Reject scalars and 2-D input explicitly: a scalar would fail with an
    # opaque IndexError on ``shape[0]``, and a (1, k) array would silently
    # build a 1x1 matrix instead of a k x k diagonal.
    if value.ndim != 1:
        raise ValueError(
            "idf_ must be a 1-D array-like of shape (n_features,); "
            "got an array with shape %s" % (value.shape,))
    n_features = value.shape[0]
    self._idf_diag = sp.spdiags(value, diags=0, m=n_features,
                                n=n_features, format='csr')


class TfidfVectorizer(CountVectorizer):
"""Convert a collection of raw documents to a matrix of TF-IDF features.
Expand Down

0 comments on commit 639a6a0

Please sign in to comment.