[Matrix] Add basic operations such as add, mul sub, etc
authorSimon Chabot <simon.chabot@logilab.fr>
Fri, 19 Oct 2012 18:22:14 +0200
changeset 35 bd2e8b9ba8bd
parent 34 6dae1ef7ecdf
child 36 ec24394673eb
[Matrix] Add basic operations such as add, mul sub, etc
matrix.py
test/test_alignment.py
--- a/matrix.py	Fri Oct 19 17:27:40 2012 +0200
+++ b/matrix.py	Fri Oct 19 18:22:14 2012 +0200
@@ -20,6 +20,7 @@
 from collections import defaultdict
 from scipy.sparse import lil_matrix
 from scipy import where
+from copy import deepcopy
 
 class Distancematrix(object):
     """ Construct and compute a matrix of distance given a distance function.
@@ -57,6 +58,52 @@
     def __repr__(self):
         return self._matrix.todense().__repr__()
 
+    def __rmul__(self, number):
+        return self * number
+
+    def __mul__(self, number):
+        if not (isinstance(number, int) or isinstance(number, float)):
+            raise NotImplementedError
+
+        other = deepcopy(self)
+        other._matrix *= number
+        other._maxdist *= number
+        return other
+
+    def __add__(self, other):
+        if not isinstance(other, Distancematrix):
+            raise NotImplementedError
+
+        result = deepcopy(self)
+        result._maxdist = self._maxdist + other._maxdist
+        result._matrix = (self._matrix + other._matrix).tolil()
+        return result
+
+    def __sub__(self, other):
+        if not isinstance(other, Distancematrix):
+            raise NotImplementedError
+
+        result = deepcopy(self)
+        result._maxdist = self._maxdist - other._maxdist
+        result._matrix = (self._matrix - other._matrix).tolil()
+        return result
+
+    def __eq__(self, other):
+        if not isinstance(other, Distancematrix):
+            return False
+
+        if (self._matrix.rows != other._matrix.rows).any():
+            return False
+
+        if (self._matrix.data != other._matrix.data).any():
+            return False
+
+        if self.distance != other.distance:
+            return False
+
+        return True
+
+
     def matched(self, cutoff = 0, normalized = False):
         match = defaultdict(list)
         if normalized:
--- a/test/test_alignment.py	Fri Oct 19 17:27:40 2012 +0200
+++ b/test/test_alignment.py	Fri Oct 19 18:22:14 2012 +0200
@@ -202,6 +202,12 @@
                         {0: [(0, d(i1[0], i2[0]))], 1: [(1, d(i1[1], i2[1])), 
                                                        (2, d(i1[1], i2[2]))]})
 
+    def test_operation(self):
+        m = self.matrix
+        self.assertEqual(3 * m, m * 3)
+        self.assertEqual((m - 0.5*m), (0.5 * m))
+        self.assertEqual(m + 10*m - m * 3, 8 * m)
+
 
 if __name__ == '__main__':
     unittest2.main()