[Minhashing] : Really, really, really faster training
authorSimon Chabot <simon.chabot@logilab.fr>
Wed, 24 Oct 2012 15:39:38 +0200
changeset 44 905240fb0f8d
parent 43 a05bc17fa163
child 45 9f4669700221
[Minhashing] : Really, really, really faster training
minhashing.py
--- a/minhashing.py	Wed Oct 24 14:51:09 2012 +0200
+++ b/minhashing.py	Wed Oct 24 15:39:38 2012 +0200
@@ -46,25 +46,23 @@
         self._trained = False
         self.sigmatrix = None
 
-    def train(self, sentences, k = 2, siglen = 200, dispTime = False):
+    def train(self, sentences, k = 2, siglen = 200):
         """ Train the minlsh on the given sentences.
 
             - `k` is the length of the k-wordgrams used
               (the lower k is, the faster is the training)
             - `siglen` the length of the sentences signature
-            - `dispTime` is used for whether the left time should be displayed
-              or not.
 
         """
 
-        matrixdocument = self._buildmatrixdocument(sentences, k, dispTime)
-        if dispTime: print "Training is done. Wait while signaturing"
+        matrixdocument = self._buildmatrixdocument(sentences, k)
+        print "Training is done. Wait while signaturing"
 
         self.sigmatrix = self._signaturematrix(matrixdocument, siglen)
         self._trained = True
 
 
-    def _buildmatrixdocument(self, sentences, k, dispTime):
+    def _buildmatrixdocument(self, sentences, k):
         """ Return a sparse matrix where :
 
             - Each sentence is a column
@@ -75,29 +73,26 @@
 
         """
 
-        sets = []
-        universe = set()
+        rows = []
+        data = []
+        universe = {}
+        sizeofuniverse = 0
         for sent in sentences:
-            sets.append([w for w in wordgrams(sent, k)])
-        universe = universe.union(*sets)
-        matrixdoc = lil_matrix((len(universe), len(sets)))
+            row = []
+            rowdata = []
+            for w in wordgrams(sent, k):
+                row.append(universe.setdefault(w, sizeofuniverse))
+                if row[-1] == sizeofuniverse:
+                    sizeofuniverse += 1
+                rowdata.append(1)
+            rows.append(row)
+            data.append(rowdata)
 
-        univlen = len(universe)
-        if dispTime:
-            prev, cur = None, time()
-        for inde, elt in enumerate(universe):
-            if dispTime and inde and inde % 200 == 0:
-                prev, cur = cur, time()
-                dt = cur - prev
-                timeleft = int(((univlen - inde) * dt) / 200)
-                print "Time left : %(min)d min %(sec)d s" % {
-                          'min' : timeleft / 60,
-                          'sec' : timeleft - 60 * (timeleft / 60)
-                        }
-            for inds, curset in enumerate(sets):
-                matrixdoc[inde, inds] = int(elt in curset)
+        matrixdoc = lil_matrix((len(rows), sizeofuniverse))
+        matrixdoc.rows = rows
+        matrixdoc.data = data
 
-        return matrixdoc
+        return matrixdoc.T
 
     def _signaturematrix(self, matrixdocument, siglen):
         """ Return a matrix where each column is the signature the document
@@ -134,17 +129,18 @@
         buckets = defaultdict(set)
 
         nbbands = int(self.sigmatrix.shape[0] / bandsize)
-        print "threshold is %.3f" % pow(1./nbbands, 1./bandsize)
+        if dispThreshold:
+            print "threshold is %.3f" % pow(1./nbbands, 1./bandsize)
+
         for r in xrange(0, self.sigmatrix.shape[0], bandsize):
             for i in xrange(len(col)):
                 stri = ''.join(str(val) for val in col[i][r:r+bandsize])
                 buckets[hash(stri)].add(i)
 
-        if sentenceid < 0 or sentenceid >= self.sigmatrix.shape[0]:
-            return set(tuple(v) for v in buckets.values() if len(v) > 1)
-
-        return set(tuple(v) for v in buckets.values()
-                   if len(v) > 1 and sentenceid in v)
+        if 0 <= sentenceid < self.sigmatrix.shape[1]:
+            return set(tuple(v) for v in buckets.values()
+                       if len(v) > 1 and sentenceid in v)
+        return set(tuple(v) for v in buckets.values() if len(v) > 1)
 
 if __name__ == '__main__':
     from cubes.alignment.normalize import (loadlemmas, simplify)