[blocking] Add compatibility for older version of sklearn, closes #252733
authorVincent Michel <vincent.michel@logilab.fr>
Mon, 26 May 2014 17:21:21 +0200
changeset 440 56114a122a56
parent 439 b5371dba546e
child 441 99863e609845
[blocking] Add compatibility for older version of sklearn, closes #252733
rl/blocking.py
--- a/rl/blocking.py	Fri May 23 15:16:33 2014 +0200
+++ b/rl/blocking.py	Mon May 26 17:21:21 2014 +0200
@@ -457,7 +457,11 @@
         idelement = tuple([0 for _ in xrange(len(refset[0][self.ref_attr_index]))])
         # We assume here that there are at least 2 elements in the refset
         n_clusters = self.n_clusters or (len(refset)/10 or len(refset)/2)
-        kmeans =  self.cluster_class(n_clusters=n_clusters)
+        try:
+            kmeans =  self.cluster_class(n_clusters=n_clusters)
+        except TypeError:
+            # Try older API version of sklearn
+            kmeans =  self.cluster_class(k=n_clusters)
         kmeans.fit([elt[self.ref_attr_index] or idelement for elt in refset])
         self.kmeans = kmeans
         # Predict on targetset
@@ -474,7 +478,8 @@
                           and containts the indexes of the record in the
                           corresponding dataset.
         """
-        neighbours = [[[], []] for _ in xrange(self.kmeans.n_clusters)]
+        n_clusters = self.kmeans.n_clusters if hasattr(self.kmeans, 'n_clusters') else self.kmeans.k
+        neighbours = [[[], []] for _ in xrange(n_clusters)]
         for ind, li in enumerate(self.predicted):
             neighbours[li][1].append(self.targetids[ind])
         for ind, li in enumerate(self.kmeans.labels_):