[aligner] Now, returned pairs also have the distance between the two elements, see #183463
authorVincent Michel <vincent.michel@logilab.fr>
Tue, 15 Oct 2013 09:33:38 +0000
changeset 317 172be468a7aa
parent 316 eecaebe54657
child 318 23b903af099d
[aligner] Now, returned pairs also have the distance between the two elements, see #183463
aligner.py
test/test_alignment.py
--- a/aligner.py	Tue Oct 15 09:32:21 2013 +0000
+++ b/aligner.py	Tue Oct 15 09:33:38 2013 +0000
@@ -35,13 +35,15 @@
             bestid, _ = sorted(global_matched[refid], key=lambda x:x[1])[0]
             ref_record = refset[refid]
             target_record = targetset[bestid]
-            yield (ref_record[0], refid), (target_record[0], bestid)
+            distance = global_mat[refid, bestid] if global_mat is not None else None
+            yield (ref_record[0], refid), (target_record[0], bestid), distance
     else:
         for refid in global_matched:
             for targetid, _ in global_matched[refid]:
                 ref_record = refset[refid]
                 target_record = targetset[targetid]
-                yield (ref_record[0], refid), (target_record[0], targetid)
+                distance = global_mat[refid, targetid] if global_mat is not None else None
+                yield (ref_record[0], refid), (target_record[0], targetid), distance
 
 
 ###############################################################################
--- a/test/test_alignment.py	Tue Oct 15 09:32:21 2013 +0000
+++ b/test/test_alignment.py	Tue Oct 15 09:33:38 2013 +0000
@@ -120,13 +120,15 @@
                                                      target_attr_index=2,
                                                      threshold=0.3))
         unimatched = list(aligner.get_aligned_pairs(refset, targetset, unique=True))
+        unimatched_wo_distance = [r[:2] for r in unimatched]
         matched = list(aligner.get_aligned_pairs(refset, targetset, unique=False))
+        matched_wo_distance = [r[:2] for r in matched]
         self.assertEqual(len(matched), len(all_matched))
         for m in all_matched:
-            self.assertIn(m, matched)
+            self.assertIn(m, matched_wo_distance)
         self.assertEqual(len(unimatched), len(uniq_matched))
         for m in uniq_matched:
-            self.assertIn(m, unimatched)
+            self.assertIn(m, unimatched_wo_distance)
 
     def test_align_from_file(self):
         uniq_matched = [(('V1', 0), ('T1', 0)), (('V2', 1), ('T3', 2)), (('V4', 3), ('T2', 1))]
@@ -141,9 +143,10 @@
                                                                       'targetfile.csv'),
                                                             ref_indexes=[0, 1, (2, 3)],
                                                             target_indexes=[0, 1, (2, 3)],))
+        matched_wo_distance = [r[:2] for r in matched]
         self.assertEqual(len(matched), len(uniq_matched))
         for m in uniq_matched:
-            self.assertIn(m, matched)
+            self.assertIn(m, matched_wo_distance)
 
 
 class PipelineAlignerTestCase(unittest2.TestCase):
@@ -168,9 +171,10 @@
         uniq_matched = [(('V1', 0), ('T1', 0)), (('V2', 1), ('T3', 2)),
                         (('V4', 3), ('T2', 1)), (('V3', 2), ('T4', 3))]
         matched = list(pipeline.get_aligned_pairs(refset, targetset, unique=True))
+        matched_wo_distance = [r[:2] for r in matched]
         self.assertEqual(len(matched), len(uniq_matched))
         for m in uniq_matched:
-            self.assertIn(m, matched)
+            self.assertIn(m, matched_wo_distance)