[blocking] Add a merge blocking, see #182023.
authorVincent Michel <vincent.michel@logilab.fr>
Tue, 08 Oct 2013 13:25:07 +0000
changeset 311 6bb993dec00b
parent 310 6e2933017fca
child 312 74d0a106fe23
[blocking] Add a merge blocking, see #182023.
blocking.py
test/test_blocking.py
--- a/blocking.py	Tue Oct 15 09:04:34 2013 +0000
+++ b/blocking.py	Tue Oct 08 13:25:07 2013 +0000
@@ -364,6 +364,77 @@
 
 
 ###############################################################################
+### MERGE BLOCKING ############################################################
+###############################################################################
+class MergeBlocking(BaseBlocking):
+    """ This blocking technique keep only one appearance of one given values,
+    and removes all the other records having this value.
+    The merge is based on a score function
+
+    E.g.
+      ('http://fr.wikipedia.org/wiki/Paris_%28Texas%29', 'Paris', 25898)
+      ('http://fr.wikipedia.org/wiki/Paris', 'Paris', 12223100)
+
+    could be (with a score function based on the population (third value):
+
+      ('http://fr.wikipedia.org/wiki/Paris', 'Paris', 12223100)
+
+    !!! WARNING !!! This is only done on ONE set (the one with a non null attr index)
+    """
+
+    def __init__(self, ref_attr_index, target_attr_index, score_func):
+        super(MergeBlocking, self).__init__(ref_attr_index, target_attr_index)
+        self.score_func = score_func
+        self.merged_dataset = None
+        self.other_dataset = None
+        if ref_attr_index is None and target_attr_index is None:
+            raise ValueError('At least one of ref_attr_index or target_attr_index '
+                             'should not be None')
+
+    def _fit(self, refset, targetset):
+        """ Fit a dataset in an index using the callback
+        """
+        if self.ref_attr_index is not None:
+            # Merge refset
+            self.merged_dataset = self._merge_dataset(refset, self.ref_attr_index)
+            self.other_dataset = [(ind, r[0]) for ind, r in enumerate(targetset)]
+        else:
+            # Merge targetset
+            self.merged_dataset = self._merge_dataset(targetset, self.target_attr_index)
+            self.other_dataset = [(ind, r[0]) for ind, r in enumerate(refset)]
+
+    def _merge_dataset(self, dataset, attr_index):
+        """ Merge a dataset
+        """
+        merged_dataset_dict = {}
+        for ind, record in enumerate(dataset):
+            score = self.score_func(record)
+            if record[attr_index] not in merged_dataset_dict:
+                # Create new entry
+                merged_dataset_dict[record[attr_index]] = (ind, record, score)
+            elif (record[attr_index] in merged_dataset_dict
+                  and merged_dataset_dict[record[attr_index]][2] < score):
+                # Change current score
+                merged_dataset_dict[record[attr_index]] = (ind, record, score)
+        return [(ind, r[0]) for ind, r, score in merged_dataset_dict.itervalues()]
+
+    def _iter_blocks(self):
+        """ Iterator over the different possible blocks.
+        """
+        if self.ref_attr_index is not None:
+            yield self.merged_dataset, self.other_dataset
+        else:
+            # self.target_attr_index is not None
+            yield self.other_dataset, self.merged_dataset
+
+    def _cleanup(self):
+        """ Cleanup blocking for further use (e.g. in pipeline)
+        """
+        self.merged_dataset = None
+        self.other_dataset = None
+
+
+###############################################################################
 ### CLUSTERING-BASED BLOCKINGS ################################################
 ###############################################################################
 class KmeansBlocking(BaseBlocking):
--- a/test/test_blocking.py	Tue Oct 15 09:04:34 2013 +0000
+++ b/test/test_blocking.py	Tue Oct 08 13:25:07 2013 +0000
@@ -24,6 +24,7 @@
 from nazca.distances import (levenshtein, soundex, soundexcode,   \
                              jaccard, euclidean, geographical)
 from nazca.blocking import (KeyBlocking, SortedNeighborhoodBlocking,
+                            MergeBlocking,
                             NGramBlocking, PipelineBlocking,
                             SoundexBlocking, KmeansBlocking,
                             MinHashingBlocking, KdTreeBlocking)
@@ -227,6 +228,46 @@
             self.assertIn(block, blocks)
 
 
+class MergeBlockingTest(unittest2.TestCase):
+
+
+    def test_merge_blocks(self):
+        blocking = MergeBlocking(ref_attr_index=1, target_attr_index=None,
+                                 score_func=lambda x:x[2])
+        refset = [('http://fr.wikipedia.org/wiki/Paris_%28Texas%29', 'Paris', 25898),
+                  ('http://fr.wikipedia.org/wiki/Paris', 'Paris', 12223100),
+                  ('http://fr.wikipedia.org/wiki/Saint-Malo', 'Saint-Malo', 46342)]
+        targetset = [('Paris (Texas)', 25000),
+                     ('Paris (France)', 12000000)]
+        true_blocks = [(['http://fr.wikipedia.org/wiki/Paris',
+                         'http://fr.wikipedia.org/wiki/Saint-Malo'],
+                        ['Paris (Texas)', 'Paris (France)'])]
+        blocking.fit(refset, targetset)
+        blocks = list(blocking.iter_id_blocks())
+        self.assertEqual(len(blocks), len(true_blocks))
+        self.assertEqual(len(blocks), len(true_blocks))
+        for block in true_blocks:
+            self.assertIn(block, blocks)
+
+    def test_merge_blocks_targetset(self):
+        blocking = MergeBlocking(ref_attr_index=None, target_attr_index=2,
+                                 score_func=lambda x:x[1])
+        refset = [('Paris (Texas)', 25000),
+                  ('Paris (France)', 12000000)]
+        targetset = [('http://fr.wikipedia.org/wiki/Paris_%28Texas%29', 25898, 'Paris'),
+                     ('http://fr.wikipedia.org/wiki/Paris', 12223100, 'Paris'),
+                     ('http://fr.wikipedia.org/wiki/Saint-Malo', 46342, 'Saint-Malo')]
+        true_blocks = [(['Paris (Texas)', 'Paris (France)'],
+                        ['http://fr.wikipedia.org/wiki/Paris',
+                         'http://fr.wikipedia.org/wiki/Saint-Malo'])]
+        blocking.fit(refset, targetset)
+        blocks = list(blocking.iter_id_blocks())
+        self.assertEqual(len(blocks), len(true_blocks))
+        self.assertEqual(len(blocks), len(true_blocks))
+        for block in true_blocks:
+            self.assertIn(block, blocks)
+
+
 class KmeansBlockingTest(unittest2.TestCase):
 
     def test_clustering_blocking_kmeans(self):