[aligner] Add kmeans to the available searchers list
authorSimon Chabot <simon.chabot@logilab.fr>
Thu, 08 Nov 2012 13:24:13 +0100
changeset 93 203140d3f9ed
parent 92 e01895cdf886
child 94 12123af1815e
[aligner] Add kmeans to the available searchers list
aligner.py
--- a/aligner.py	Thu Nov 08 11:41:01 2012 +0100
+++ b/aligner.py	Thu Nov 08 13:24:13 2012 +0100
@@ -39,16 +39,27 @@
 
 
 def findneighbours(alignset, targetset, indexes = (1, 1), mode = 'kdtree',
-                   threshold = 0.1, extraargs = {}):
+                   threshold = 0.1, k = 2, n_clusters = None):
+
+    SEARCHERS = set(['kdtree', 'minhashing', 'kmeans', 'minibatch'])
+    mode = mode.lower()
+
+    if mode not in SEARCHERS:
+        raise NotImplementedError('Unknown mode given')
+
+##### KDTree #######
     if mode == 'kdtree':
+        # XXX : If there are more than 2 dimensions ??
         aligntree  = KDTree([elt[indexes[0]] or (0, 0) for elt in alignset])
         targettree = KDTree([elt[indexes[1]] or (0, 0) for elt in targetset])
         return aligntree.query_ball_tree(targettree, threshold)
+
+#### Minhashing #####
     elif mode == 'minhashing':
         minhasher = Minlsh()
         minhasher.train([elt[indexes[0]] or '' for elt in alignset] +
                         [elt[indexes[1]] or '' for elt in targetset],
-                        **extraargs)
+                        k)
         rawneighbours = minhasher.findsimilarsentences(threshold)
         neighbours = [[] for _ in xrange(len(alignset))]
         for data in rawneighbours:
@@ -59,6 +70,27 @@
                                       for e in data if e >= len(alignset)])
         return neighbours
 
+#### Kmeans #####
+    elif mode in set(['kmeans', 'minbatch']):
+        from sklearn import cluster
+        if mode == 'kmeans':
+            kmeans = cluster.KMeans(n_clusters=n_clusters or (len(alignset)/100))
+        else:
+            kmeans = cluster.MiniBatchKMeans(n_clusters=n_clusters or (len(alignset)/100))
+        # XXX : If there are more than 2 dimensions ??
+        kmeans.fit([elt[indexes[0]] or (0, 0) for elt in alignset])
+        predicted = kmeans.predict([elt[indexes[1]] or (0, 0) for elt in targetset])
+
+        clusters = [[] for _ in xrange(kmeans.n_clusters)]
+        print kmeans.n_clusters
+        for ind, j in enumerate(predicted):
+            clusters[j].append(ind)
+        neighbours = []
+        labels = kmeans.labels_
+        for i in xrange(len(alignset)):
+            neighbours.append(clusters[labels[i]])
+        return neighbours
+
 def align(alignset, targetset, treatments, threshold, resultfile):
     """ Try to align the items of alignset onto targetset's ones