[distance] It is now possible to pass a tuple of indexes to the distance processing to create a subrecord
author Vincent Michel <vincent.michel@logilab.fr>
Tue, 24 Jun 2014 23:05:19 +0000
changeset 451 db82a2d5cf88
parent 450 28ba90f7f947
child 453 7debe5d58a19
[distance] It is now possible to pass a tuple of indexes to the distance processing to create a subrecord
test/test_distances.py
utils/distances.py
--- a/test/test_distances.py	Tue Jun 24 23:02:40 2014 +0000
+++ b/test/test_distances.py	Tue Jun 24 23:05:19 2014 +0000
@@ -139,11 +139,32 @@
 class GeographicalTestCase(unittest.TestCase):
 
     def test_geographical(self):
+        # Use the whole record
         processing = GeographicalProcessing(units='km')
         _input = ((48.856578, 2.351828), (51.504872, -0.07857))
         pdist = processing.pdist(_input)
         self.assertEqual([341.56415945105], pdist)
 
+    def test_geographical_2(self):
+        # Use a single column of the record (tuple version)
+        processing = GeographicalProcessing(ref_attr_index=1,
+                                            target_attr_index=1,
+                                            units='km')
+        _input = (('paris', (48.856578, 2.351828)),
+                  ('london', (51.504872, -0.07857)))
+        pdist = processing.pdist(_input)
+        self.assertEqual([341.56415945105], pdist)
+
+    def test_geographical_3(self):
+        # Use two columns of the record
+        processing = GeographicalProcessing(ref_attr_index=(1,2),
+                                            target_attr_index=(1,2),
+                                            units='km')
+        _input = (('paris', 48.856578, 2.351828),
+                  ('london', 51.504872, -0.07857))
+        pdist = processing.pdist(_input)
+        self.assertEqual([341.56415945105], pdist)
+
 
 class ExactMatchTestCase(unittest.TestCase):
 
--- a/utils/distances.py	Tue Jun 24 23:02:40 2014 +0000
+++ b/utils/distances.py	Tue Jun 24 23:05:19 2014 +0000
@@ -381,6 +381,14 @@
         self.weight = weight
         self.matrix_normalized = matrix_normalized
 
+    def build_record(self, record, index):
+        """ Allow ref_attr_index and target_attr_index to be a couple
+        of indexes, e.g. for (latitude, longitude) """
+        if isinstance(index, tuple) and len(index) == 2:
+            return (record[index[0]], record[index[1]])
+        else:
+            return (record[index] if index else record)
+
     def distance(self, reference_record, target_record):
         """ Compute the distance between two records
 
@@ -391,11 +399,8 @@
         target_record: a record (tuple/list of values) of the target dataset.
 
         """
-        refrecord = (reference_record[self.ref_attr_index] if self.ref_attr_index
-                     else reference_record)
-        targetrecord = (target_record[self.target_attr_index] if self.target_attr_index
-                        else target_record)
-        return self.distance_callback(refrecord, targetrecord)
+        return self.distance_callback(self.build_record(reference_record, self.ref_attr_index),
+                                      self.build_record(target_record, self.target_attr_index))
 
     def cdist(self, refset, targetset, ref_indexes=None, target_indexes=None):
         """ Compute the metric matrix, given two datasets and a metric
@@ -484,6 +489,7 @@
                                                     distance_callback,
                                                     weight, matrix_normalized)
 
+
 class SoundexProcessing(BaseProcessing):
     """ A processing based on the soundex distance.
     """