author Simon Chabot Fri, 19 Oct 2012 18:22:14 +0200 changeset 35 bd2e8b9ba8bd parent 34 6dae1ef7ecdf child 36 ec24394673eb
 matrix.py file | annotate | diff | comparison | revisions test/test_alignment.py file | annotate | diff | comparison | revisions
```--- a/matrix.py	Fri Oct 19 17:27:40 2012 +0200
+++ b/matrix.py	Fri Oct 19 18:22:14 2012 +0200
@@ -20,6 +20,7 @@
from collections import defaultdict
from scipy.sparse import lil_matrix
from scipy import where
+from copy import deepcopy

class Distancematrix(object):
""" Construct and compute a matrix of distance given a distance function.
@@ -57,6 +58,52 @@
def __repr__(self):
return self._matrix.todense().__repr__()

+    def __rmul__(self, number):
+        return self * number
+
+    def __mul__(self, number):
+        if not (isinstance(number, int) or isinstance(number, float)):
+            raise NotImplementedError
+
+        other = deepcopy(self)
+        other._matrix *= number
+        other._maxdist *= number
+        return other
+
+        if not isinstance(other, Distancematrix):
+            raise NotImplementedError
+
+        result = deepcopy(self)
+        result._maxdist = self._maxdist + other._maxdist
+        result._matrix = (self._matrix + other._matrix).tolil()
+        return result
+
+    def __sub__(self, other):
+        if not isinstance(other, Distancematrix):
+            raise NotImplementedError
+
+        result = deepcopy(self)
+        result._maxdist = self._maxdist - other._maxdist
+        result._matrix = (self._matrix - other._matrix).tolil()
+        return result
+
+    def __eq__(self, other):
+        if not isinstance(other, Distancematrix):
+            return False
+
+        if (self._matrix.rows != other._matrix.rows).any():
+            return False
+
+        if (self._matrix.data != other._matrix.data).any():
+            return False
+
+        if self.distance != other.distance:
+            return False
+
+        return True
+
+
def matched(self, cutoff = 0, normalized = False):
match = defaultdict(list)
if normalized:```
```--- a/test/test_alignment.py	Fri Oct 19 17:27:40 2012 +0200
+++ b/test/test_alignment.py	Fri Oct 19 18:22:14 2012 +0200
@@ -202,6 +202,12 @@
{0: [(0, d(i1[0], i2[0]))], 1: [(1, d(i1[1], i2[1])),
(2, d(i1[1], i2[2]))]})

+    def test_operation(self):
+        m = self.matrix
+        self.assertEqual(3 * m, m * 3)
+        self.assertEqual((m - 0.5*m), (0.5 * m))
+        self.assertEqual(m + 10*m - m * 3, 8 * m)
+

if __name__ == '__main__':
unittest2.main()```