[aligner] Add tools to align from files, see #182030
authorVincent Michel <vincent.michel@logilab.fr>
Tue, 08 Oct 2013 13:18:13 +0000
changeset 304 3ad94b0b5322
parent 303 52ba47aa143f
child 305 afe25ae5facf
[aligner] Add tools to align from files, see #182030
aligner.py
test/test_alignment.py
--- a/aligner.py	Tue Oct 15 08:10:03 2013 +0000
+++ b/aligner.py	Tue Oct 08 13:18:13 2013 +0000
@@ -20,18 +20,22 @@
 from scipy import zeros
 from scipy.sparse import lil_matrix
 
+from nazca.dataio import parsefile
+
 
 ###############################################################################
 ### BASE ALIGNER OBJECT #######################################################
 ###############################################################################
 class BaseAligner(object):
 
-    def __init__(self, threshold, processings):
+    def __init__(self, threshold, processings, verbose=False):
         self.threshold = threshold
         self.processings = processings
+        self.verbose = verbose
         self.ref_normalizer = None
         self.target_normalizer = None
         self.blocking = None
+        self.nb_comparisons = 0
 
     def register_ref_normalizer(self, normalizer):
         """ Register normalizers to be applied
@@ -111,38 +115,120 @@
         for refblock, targetblock in self.blocking.iter_blocks():
             ref_index = [r[0] for r in refblock]
             target_index = [r[0] for r in targetblock]
-            print ref_index, target_index
+            self.nb_comparisons += len(ref_index)*len(target_index)
+            if self.verbose:
+                print 'Blocking: %s reference ids, %s target ids' % (len(ref_index),
+                                                                     len(target_index))
+                print 'Reference records :'
+                for ind in ref_index:
+                    print '\t--->', refset[ind]
+                print 'Target records :'
+                for ind in target_index:
+                    print '\t--->', targetset[ind]
             _, matched = self._get_match(refset, targetset, ref_index, target_index)
+            if self.verbose:
+                print 'Matched: %s / Total comparisons %s' % (len(matched), self.nb_comparisons)
             for k, values in matched.iteritems():
                 subdict = global_matched.setdefault(k, set())
                 for v, d in values:
                     subdict.add((v, d))
-                    # XXX avoid issue in sparse matrix
                     if get_matrix:
+                        # XXX avoid issue in sparse matrix
                         global_mat[k, v] = d or 10**(-10)
         return global_mat, global_matched
 
+    def _iter_aligned_pairs(self, refset, targetset, global_mat, global_matched, unique=True):
+        """ Return the aligned pairs
+        """
+        if unique:
+            for refid in global_matched:
+                bestid, _ = sorted(global_matched[refid], key=lambda x:x[1])[0]
+                ref_record = refset[refid]
+                target_record = targetset[bestid]
+                if self.verbose:
+                    print '\t\t', ref_record, ' <--> ', target_record
+                yield (ref_record[0], refid), (target_record[0], bestid)
+        else:
+            for refid in global_matched:
+                for targetid, _ in global_matched[refid]:
+                    ref_record = refset[refid]
+                    target_record = targetset[targetid]
+                    if self.verbose:
+                        print '\t\t', ref_record, ' <--> ', target_record
+                    yield (ref_record[0], refid), (target_record[0], targetid)
+        print 'Total comparisons : ', self.nb_comparisons
+
     def get_aligned_pairs(self, refset, targetset, unique=True):
         """ Get the pairs of aligned elements
         """
-        global_mat, global_matched = self.align(refset, targetset, False)
-        if unique:
-            for refid in global_matched:
-                bestid, _ = sorted(global_matched[refid], key=lambda x:x[1])[0]
-                yield refset[refid][0], targetset[bestid][0]
-        else:
-            for refid in global_matched:
-                for targetid, _ in global_matched[refid]:
-                    yield refset[refid][0], targetset[targetid][0]
+        global_mat, global_matched = self.align(refset, targetset, get_matrix=False)
+        for pair in self._iter_aligned_pairs(refset, targetset, global_mat, global_matched, unique):
+            yield pair
+
+    def align_from_files(self, reffile, targetfile,
+                         ref_indexes=None, target_indexes=None,
+                         ref_encoding=None, target_encoding=None,
+                         ref_separator='\t', target_separator='\t',
+                         get_matrix=True):
+        """ Align data from files
+
+        Parameters
+        ----------
+
+        reffile: name of the reference file
+
+        targetfile: name of the target file
+
+        ref_encoding: if given (e.g. 'utf-8' or 'latin-1'), it will
+                      be used to read the files.
+
+        target_encoding: if given (e.g. 'utf-8' or 'latin-1'), it will
+                         be used to read the files.
+
+        ref_separator: separator of the reference file
+
+        target_separator: separator of the target file
+        """
+        refset = parsefile(reffile, indexes=ref_indexes,
+                           encoding=ref_encoding, delimiter=ref_separator)
+        targetset = parsefile(targetfile, indexes=target_indexes,
+                              encoding=target_encoding, delimiter=target_separator)
+        return self.align(refset, targetset, get_matrix=get_matrix)
+
+    def get_aligned_pairs_from_files(self, reffile, targetfile,
+                         ref_indexes=None, target_indexes=None,
+                         ref_encoding=None, target_encoding=None,
+                         ref_separator='\t', target_separator='\t',
+                         unique=True):
+        """ Get the pairs of aligned elements
+        """
+        refset = parsefile(reffile, indexes=ref_indexes,
+                           encoding=ref_encoding, delimiter=ref_separator)
+        targetset = parsefile(targetfile, indexes=target_indexes,
+                              encoding=target_encoding, delimiter=target_separator)
+        global_mat, global_matched = self.align(refset, targetset, get_matrix=False)
+        for pair in self._iter_aligned_pairs(refset, targetset, global_mat, global_matched, unique):
+            yield pair
 
 
-## ###############################################################################
-## ### ITERATIVE ALIGNER OBJECT ##################################################
-## ###############################################################################
-## class MultiPassAligner(object):
-##     """ This aligner may be used to perform multi pass of alignements.
-##     Records linked in a previous pass will not be consider in the nex pass.
-##     """
+###############################################################################
+### ITERATIVE ALIGNER OBJECT ##################################################
+###############################################################################
+class IterativePassAligner(object):
+    """ This aligner may be used to perform multi pass of alignements.
+
+        It takes your csv files as arguments and split them into smaller ones
+        (files of `size` lines), and runs the alignment on those files.
 
-##     def __init__(self, threshold, treatments):
- 
+        If the distance of an alignment is below `equality_threshold`, the
+        alignment is considered as perfect, and the corresponding item is
+        removed from the alignset (to speed up the computation).
+    """
+
+    def __init__(self, threshold, treatments, equality_threshold):
+        self.threshold = threshold
+        self.treatments = treatments
+        self.equality_threshold = equality_threshold
+        self.ref_normalizer = None
+        self.target_normalizer = None
+        self.blocking = None
--- a/test/test_alignment.py	Tue Oct 15 08:10:03 2013 +0000
+++ b/test/test_alignment.py	Tue Oct 08 13:18:13 2013 +0000
@@ -51,7 +51,7 @@
             for v, distance in values:
                 self.assertIn((k,v), true_matched)
 
-    def test_neighbours_align(self):
+    def test_blocking_align(self):
         refset = [['V1', 'label1', (6.14194444444, 48.67)],
                   ['V2', 'label2', (6.2, 49)],
                   ['V3', 'label3', (5.1, 48)],
@@ -70,14 +70,14 @@
                                       threshold=0.3)
         blocking.fit(refset, targetset)
         predict_matched = set()
-        for alignind, targetind in blocking.iter_blocks():
+        for alignind, targetind in blocking.iter_indice_blocks():
             mat, matched = aligner._get_match(refset, targetset, alignind, targetind)
             for k, values in matched.iteritems():
                 for v, distance in values:
                     predict_matched.add((k, v))
         self.assertEqual(true_matched, predict_matched)
 
-    def test_divide_and_conquer_align(self):
+    def test_blocking_align_2(self):
         refset = [['V1', 'label1', (6.14194444444, 48.67)],
                   ['V2', 'label2', (6.2, 49)],
                   ['V3', 'label3', (5.1, 48)],
@@ -101,7 +101,7 @@
                 predict_matched.add((k, v))
         self.assertEqual(true_matched, predict_matched)
 
-    def test_alignall(self):
+    def test_unique_align(self):
         refset = [['V1', 'label1', (6.14194444444, 48.67)],
                     ['V2', 'label2', (6.2, 49)],
                     ['V3', 'label3', (5.1, 48)],
@@ -111,8 +111,9 @@
                      ['T2', 'labelt2', (5.3, 48.2)],
                      ['T3', 'labelt3', (6.25, 48.91)],
                      ]
-        all_matched = [('V1','T1'), ('V1', 'T3'), ('V2','T3'), ('V4','T2')]
-        uniq_matched = [('V2', 'T3'), ('V4', 'T2'), ('V1', 'T1')]
+        all_matched = [(('V1', 0), ('T3', 2)), (('V1', 0), ('T1', 0)),
+                       (('V2', 1), ('T3', 2)), (('V4', 3), ('T2', 1))]
+        uniq_matched = [(('V1', 0), ('T1', 0)), (('V2', 1), ('T3', 2)), (('V4', 3), ('T2', 1))]
         processings = (GeographicalProcessing(2, 2, units='km'),)
         aligner = alig.BaseAligner(threshold=30, processings=processings)
         aligner.register_blocking(blo.KdTreeBlocking(ref_attr_index=2,
@@ -127,6 +128,22 @@
         for m in uniq_matched:
             self.assertIn(m, unimatched)
 
+    def test_align_from_file(self):
+        uniq_matched = [(('V1', 0), ('T1', 0)), (('V2', 1), ('T3', 2)), (('V4', 3), ('T2', 1))]
+        processings = (GeographicalProcessing(2, 2, units='km'),)
+        aligner = alig.BaseAligner(threshold=30, processings=processings)
+        aligner.register_blocking(blo.KdTreeBlocking(ref_attr_index=2,
+                                                     target_attr_index=2,
+                                                     threshold=0.3))
+        matched = list(aligner.get_aligned_pairs_from_files(path.join(TESTDIR, 'data',
+                                                                      'alignfile.csv'),
+                                                            path.join(TESTDIR, 'data',
+                                                                      'targetfile.csv'),
+                                                            ref_indexes=[0, 1, (2, 3)],
+                                                            target_indexes=[0, 1, (2, 3)],))
+        self.assertEqual(len(matched), len(uniq_matched))
+        for m in uniq_matched:
+            self.assertIn(m, matched)
 
 if __name__ == '__main__':
     unittest2.main()