[distances] Extract cdist in a function, see #183470
authorVincent Michel <vincent.michel@logilab.fr>
Tue, 15 Oct 2013 11:30:53 +0000
changeset 318 23b903af099d
parent 317 172be468a7aa
child 319 40955d08e971
[distances] Extract cdist in a function, see #183470
distances.py
--- a/distances.py	Tue Oct 15 09:33:38 2013 +0000
+++ b/distances.py	Tue Oct 15 11:30:53 2013 +0000
@@ -30,6 +30,36 @@
 ###############################################################################
 ### UTILITY FUNCTIONS #########################################################
 ###############################################################################
+def cdist(distance_callback, refset, targetset, matrix_normalized=False,
+          ref_indexes=None, target_indexes=None):
+    """ Compute the metric matrix, given two datasets and a metric
+
+    Parameters
+    ----------
+    refset: a dataset (list of records)
+
+    targetset: a dataset (list of records)
+
+    Returns
+    -------
+
+    A distance matrix, of shape (len(refset), len(targetset))
+    with the distance of each element in it.
+    """
+    ref_indexes = ref_indexes or xrange(len(refset))
+    target_indexes = target_indexes or xrange(len(targetset))
+    distmatrix = empty((len(ref_indexes), len(target_indexes)), dtype='float32')
+    size = distmatrix.shape
+    for i, iref in enumerate(ref_indexes):
+        for j, jref in enumerate(target_indexes):
+            d = 1
+            if refset[iref] and targetset[jref]:
+                d = distance_callback(refset[iref], targetset[jref])
+                if matrix_normalized:
+                    d = 1 - (1.0/(1.0 + d))
+            distmatrix[i, j] = d
+    return distmatrix
+
 def _handlespaces(stra, strb, distance, tokenizer=None, **kwargs):
     """ Compute the matrix of distances between all tokens of stra and strb
         (with function ``distance``). Extra args are given to the distance
@@ -364,19 +394,9 @@
         A distance matrix, of shape (len(refset), len(targetset))
         with the distance of each element in it.
         """
-        ref_indexes = ref_indexes or xrange(len(refset))
-        target_indexes = target_indexes or xrange(len(targetset))
-        distmatrix = empty((len(ref_indexes), len(target_indexes)), dtype='float32')
-        size = distmatrix.shape
-        for i, iref in enumerate(ref_indexes):
-            for j, jref in enumerate(target_indexes):
-                d = 1
-                if refset[iref] and targetset[jref]:
-                    d = self.distance(refset[iref], targetset[jref])
-                    if self.matrix_normalized:
-                        d = 1 - (1.0/(1.0 + d))
-                distmatrix[i, j] = d
-        return distmatrix
+        return cdist(self.distance, refset, targetset,
+                     matrix_normalized=self.matrix_normalized,
+                     ref_indexes=ref_indexes, target_indexes=target_indexes)
 
     def pdist(self, dataset):
         """ Compute the upper triangular matrix in a way similar