Skip to content

Commit

Permalink
Merge pull request scikit-learn#3786 from AlexanderFabisch/tsne_fix
Browse files Browse the repository at this point in the history
Fix t-SNE with "non-squarable" metric
  • Loading branch information
larsmans committed Oct 21, 2014
2 parents d4405cb bf7993e commit a0e5fcb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
11 changes: 8 additions & 3 deletions sklearn/manifold/t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 333,7 @@ class TSNE(BaseEstimator):
Maximum number of iterations for the optimization. Should be at
least 200.

metric : string or callable, (default: "euclidean")
metric : string or callable, optional
The metric to use when calculating distance between instances in a
feature array. If metric is a string, it must be one of the options
allowed by scipy.spatial.distance.pdist for its metric parameter, or
Expand All @@ -342,7 342,8 @@ class TSNE(BaseEstimator):
Alternatively, if metric is a callable function, it is called on each
pair of instances (rows) and the resulting value recorded. The callable
should take two arrays from X as input and return a value indicating
the distance between them.
the distance between them. The default is "euclidean" which is
interpreted as squared euclidean distance.

init : string, optional (default: "random")
Initialization of embedding. Possible options are 'random' and 'pca'.
Expand Down Expand Up @@ -432,7 433,11 @@ def _fit(self, X):
else:
if self.verbose:
print("[t-SNE] Computing pairwise distances...")
distances = pairwise_distances(X, metric=self.metric, squared=True)

if self.metric == "euclidean":
distances = pairwise_distances(X, metric=self.metric, squared=True)
else:
distances = pairwise_distances(X, metric=self.metric)

# Degrees of freedom of the Student's t-distribution. The suggestion
# alpha = n_components - 1 comes from "Learning a Parametric Embedding
Expand Down
9 changes: 9 additions & 0 deletions sklearn/manifold/tests/test_t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 254,12 @@ def test_verbose():
assert("Finished" in out)
assert("early exaggeration" in out)
assert("Finished" in out)


def test_chebyshev_metric():
"""t-SNE should allow metrics that cannot be squared (issue #3526)."""
random_state = check_random_state(0)
tsne = TSNE(verbose=2, metric="chebyshev")
X = random_state.randn(5, 2)
tsne.fit_transform(X)

0 comments on commit a0e5fcb

Please sign in to comment.