wip add diprelib draft
author Vincent Michel <vincent.michel@logilab.fr>
Mon, 30 Jun 2014 15:13:48 +0000
changeset 467 2d1a782af235
parent 466 a507ff7a2ced
wip add diprelib
diprelib.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/diprelib.py	Mon Jun 30 15:13:48 2014 +0000
@@ -0,0 +1,462 @@
+#-*- coding:utf8 -*-
+
+import re
+from nltk import RegexpParser
+from cPickle import load
+from nltk.tokenize import wordpunct_tokenize #SpaceTokenizer, WordPuncTokenizer
+
+def load_tagger(filename):
+    """ Load a tagger
+    """
+    fobj = open(filename, 'rb')
+    tagger = load(fobj)
+    fobj.close()
+    return tagger
+
+def unicode_tokenize(sentence):
+    """ Tokenize french unicode
+    """
+    #tokenize = WordPuncTokenizer() #SpaceTokenizer()
+    #return [w for w in tokenize.tokenize(sentence)]
+    return wordpunct_tokenize(sentence)
+
+def tag_sentence(sentence, tagger):
+    """ Tag a sentence
+    """
+    words = unicode_tokenize(sentence)
+    postags = tagger.tag(words)
+    if not words:
+        return postags
+    # Deal with capital at the beginning of a sentence
+    postag = tagger.tag((words[0].lower(),))
+    if postag[0][1] is not None:
+        postags[0] = (words[0], postag[0][1])
+    return postags
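+
+# Illustrative sketch (not executed): the tokenization below is what
+# wordpunct_tokenize produces, while the POS tags are purely hypothetical,
+# since they depend on the pickled tagger you load.
+# >>> unicode_tokenize(u"Jules Verne, 1873.")
+# [u'Jules', u'Verne', u',', u'1873', u'.']
+# >>> tag_sentence(u"Le chat dort.", tagger)   # hypothetical tagset
+# [(u'Le', u'DET'), (u'chat', u'NOM'), (u'dort', u'VER'), (u'.', u'PUN')]
+# The lowercased re-tagging of the first word handles taggers that only
+# know u'le', not the capitalized u'Le' found at the start of a sentence.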
+
+def longuest_common_suffix(words):
+    """ Return the longuest common suffix of words
+    """
+    suffix = ''
+    common = True
+    while common and len(suffix) < len(words[0]):
+        suffix = words[0][-(len(suffix) + 1)] + suffix
+        for word in words:
+            if not word.endswith(suffix):
+                common = False
+                break
+    return suffix if common else suffix[1:]
+
+def longuest_common_prefix(words):
+    """ Return the longuest common prefix of words
+    """
+    prefix = ''
+    common = True
+    while common and len(prefix) < len(words[0]):
+        prefix += words[0][len(prefix)]
+        for word in words:
+            if not word.startswith(prefix):
+                common = False
+                break
+    return prefix if common else prefix[:-1]
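+
+# Quick illustration of the two helpers above (values checked by hand):
+# >>> longuest_common_suffix([u'chanson', u'maison'])
+# u'son'
+# >>> longuest_common_prefix([u'porte', u'portrait'])
+# u'port'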
+
+def escape_regex(string):
+    """ Escape string for a regex application
+    """
+    caracteres = '\.?+*[]()^$'
+    for car in caracteres :
+        string = string.replace(car, '\\' + car)
+    return string
+
+def check_list(mylist, indexes_editable = list()):
+    """ List each element of ``mylist`` and ask the user for what to do with
+
+        indexes_editable provided, enables the user to edit the i-eme component of each
+        element. For instance, indexes_editable = [0, 3], allows the user to edit the
+        first and the fourth component of each element of mylist (!!! Doesn't
+        mean the first and fourth component of ``mylist`` are editable, but the
+        first and the fourth component of EACH element OF ``mylist``
+    """
+    checked_list = []
+    print 'Press : \n',
+    if indexes_editable:
+        print '\t+ [e]     to edit\n'
+    print '\t+ [d]     to delete\n'
+    print '\t+ [other] to keep\n'
+
+    for elt in mylist:
+        print elt,
+        choice = raw_input('> ').decode('utf-8').strip()
+        if choice == u'd':
+            continue
+
+        if choice == u'e':
+            for ind in indexes_editable:
+                if ind >= len(elt):
+                    continue
+                new = raw_input('%s >>> ' % elt[ind].encode('utf-8'))
+                elt[ind] = new.decode('utf-8')
+        checked_list.append(elt)
+    return checked_list
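+
+# Sketch of an interactive session (purely illustrative, help header omitted,
+# user input shown after the prompts):
+# >>> check_list([[0, u'(.*?)a ecrit(.*?)']], indexes_editable=[1])
+# [0, u'(.*?)a ecrit(.*?)'] > e
+# (.*?)a ecrit(.*?) >>> (.*?) a ecrit (.*?)
+# Typing 'd' at the first prompt would drop the element; any other input
+# keeps it unchanged.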
+
+
+class Dipre(object):
+    """ Implementation of DIPRE (Dual Iterative Pattern Relation Expansion)
+        algorithm
+    """
+
+    def __init__(self, database, tagger, marker = None,
+                 database_encoding = 'utf-8', verbose = False):
+        """
+            + database : the corpus to read
+            + marker   : separator between two sentences (if None, one line =
+                         one sentence)
+            + tagger   : a nltk postagger
+        """
+        self.database = database
+        self.database_encoding = database_encoding
+        self.marker = marker
+        self.seeds = set([])
+        self.patterns = set([])
+        self.tagger = load_tagger(tagger)
+        self.verbose = verbose
+
+    def _vprint(self, *args):
+        """ Print the arguments if, and only if, verbose is on
+        """
+        if not self.verbose:
+            return
+        for arg in args:
+            print arg,
+        print
+
+    def set_seeds(self, seeds):
+        """ Initialize the seeds
+            Seeds is a set of seed.
+            A seed is built as :
+                ((word0, tag_regex0), (word1, tag_regex1))
+
+            tag_regex0 can be set to None,
+            meaning “everything is matched”
+
+            For instance :
+            seeds = set([
+                         ((u'Victor Hugo', '<NAM>+'), (u'les Misérables', None)),
+                       ])
+        """
+        self.seeds = set([])
+        for seed in seeds:
+            self.add_seed(seed)
+        if len(self.seeds) < 5:
+            self._vprint("The seed is set to : ", self.seeds)
+        else:
+            self._vprint("There are", len(self.seeds), "seeds now")
+
+    def add_seed(self, seed):
+        """ Add a seed to the current ones
+        """
+        #escape seeds before processing
+        ((word0, tag0), (word1, tag1)) = seed
+        word0 = escape_regex(word0)
+        word1 = escape_regex(word1)
+        self.seeds.add(((word0, tag0), (word1, tag1)))
+
+    def find_sentences(self, ind_max = 30000000, card_max = 300):
+        """ Read the database and return sentences where the words appear
+        """
+        fobj = open(self.database)
+        sents_seeds = []
+        fobj.readline() #Ignore the first line
+        card = 0
+        for ind, line in enumerate(fobj):
+            if (ind_max and ind > ind_max) or card > card_max:
+                break
+            lines = line.decode(self.database_encoding)
+            if self.marker:
+                lines = lines.split(self.marker)
+            else:
+                lines = [lines]
+            for sent in lines:
+                sent = sent.strip()
+                if not sent:
+                    continue
+                for seed in self.seeds:
+                    remaining = list(seed)
+                    for word, tags in seed:
+                        if re.search(word.lower(), sent.lower()):
+                            remaining.remove((word, tags))
+                    if not remaining: #ie: all words have been found
+                        sents_seeds.append((sent, seed))
+                        card += 1
+                        self._vprint(card, ')'+ sent.encode('utf-8') + '\n')
+                        break
+                    #else, try with an other seed
+        return sents_seeds
+
+    def build_occurrences(self, seed, sentence, data_len = 20):
+        """Return a list of occurrences built as this one :
+            [order, (word0, word1), (reg0, reg1), prefix, suffix, middle]
+
+            * data_len is used to limit the size of prefix and suffix.
+            * order : 0 means word0 is found /before/ word1
+                      1 otherwise
+
+            /!\ We assume that the list `words` contains only two words.
+                Others are simply ignored
+        """
+        occ = []
+        first, second = 0, 1
+        words = [w.lower() for (w, _) in seed]
+        regex_tag = [reg for (_, reg) in seed]
+        sentence_l = sentence.lower()
+        prefix = re.search(words[first], sentence_l)
+        suffix = re.search(words[second], sentence_l)
+
+        if prefix.start() > suffix.start():
+            first, second = second, first
+            prefix, suffix = suffix, prefix
+
+        pos_begin = max(0, prefix.start() - data_len)
+        pos_end = min(len(sentence), suffix.end() + data_len)
+        occ.append(first)
+
+
+        occ.append([words[first], words[second]])
+        occ.append([regex_tag[first], regex_tag[second]])
+
+        occ.append(sentence[pos_begin:prefix.start()])      #prefix
+        occ.append(sentence[suffix.end():pos_end])          #suffix
+        occ.append(sentence[prefix.end():suffix.start()])   #middle
+
+        return occ
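+
+    # Illustration (not executed): with the seed from the set_seeds()
+    # docstring and the invented sentence
+    #   u"Victor Hugo a écrit les Misérables en 1862"
+    # build_occurrences() returns roughly:
+    #   [0,                                    # word0 found before word1
+    #    [u'victor hugo', u'les misérables'],  # lowercased seed words
+    #    ['<NAM>+', None],                     # tag regexes
+    #    u'',                                  # prefix (nothing before word0)
+    #    u' en 1862',                          # suffix
+    #    u' a écrit ']                         # middle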
+
+    def extract_patterns(self, occurrences):
+        """Return a list of patterns, built from a list of occurrences returned
+           by build_occurrences()
+        """
+        G = ({}, {}) #One dictionnay per direction
+        self.patterns = []
+        min_card = 2 #Miminal times a occ. must appear to be taken into account
+        for (direction, _, regex_tag, pre, suf, mid) in occurrences :
+            G[direction].setdefault(mid, []).append((pre, suf, regex_tag))
+        for direction in (0, 1):
+            for (mid, pre_suf_tag) in G[direction].items():
+                #Remove some noise
+                if len(pre_suf_tag) < min_card:
+                    G[direction].pop(mid)
+                    continue
+                #If the pattern looks interesting
+                common_pref = longuest_common_suffix([p for (p, _, _) in pre_suf_tag])
+                common_suff = longuest_common_prefix([s for (_, s, _) in pre_suf_tag])
+                regex_tag   = pre_suf_tag[0][2]
+
+                common_pref = escape_regex(common_pref).strip()
+                common_suff = escape_regex(common_suff).strip()
+                mid = escape_regex(mid).strip()
+
+                common_pref += '(.*?)'
+                if len(common_suff) == 0:
+                    common_suff = ' *(.*)' #*\S+(( .\.)?[ -\']\S+)?)' #One/two word(s)
+                else:
+                    common_suff = '(.*?)' + common_suff
+                self.patterns.append(
+                    [direction,
+                     u"%s%s%s" % (common_pref, mid, common_suff),
+                     regex_tag])
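+
+    # Illustration (not executed): two occurrences sharing the middle
+    # u' a écrit ', with suffixes u' en 1862' and u' en 1831', would yield a
+    # pattern roughly like:
+    #   [0, u'(.*?)a écrit(.*?)en 18', ['<NAM>+', None]]
+    # i.e. [direction, regex with two capturing groups, tag regexes].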
+
+    def find_new_relations(self, ind_max = 30000000):
+        """ Find new relations in the database, matching a pattern from
+            self.patterns
+
+            Read at most ``ind_max`` lines (can be None, meaning the whole
+            file)
+        """
+        fobj = open(self.database)
+        min_len = 4
+        fobj.readline() #Ignore the first line
+        for ind, line in enumerate(fobj):
+            if ind_max and ind > ind_max:
+                break
+            if ind % 50000 == 0:
+                self._vprint( "-" * 10 + '>', ind_max - ind)
+            lines = line.decode(self.database_encoding)
+            if self.marker:
+                lines = lines.split(self.marker)
+            else:
+                lines = [lines]
+            for sent in lines:
+                sent = sent.strip()
+                if not sent:
+                    continue
+                for pattern in self.patterns:
+                    match = re.search(pattern[1], sent)
+                    if match:
+                        tagged_sentence = tag_sentence(sent, self.tagger)
+                        joined_sentence = '###'.join([w for (w, _)
+                                                        in tagged_sentence])
+                        elt1, elt2 = {}, {}
+
+                        elt1['group'] = match.group(1).strip()
+                        elt1['tokened'] = unicode_tokenize(elt1['group'])
+                        joined_token = '###'.join(elt1['tokened'])
+                        index = joined_sentence.index(joined_token)
+                        start = joined_sentence[:index].count('###')
+                        elt1['tagged'] = [(w, t) for (w, t)
+                                                 in tagged_sentence[start:]]
+                        elt1['regex'] = pattern[2][0]
+
+
+
+                        elt2['group'] = match.group(2).strip()
+                        elt2['tokened'] = unicode_tokenize(elt2['group'])
+                        joined_token = '###'.join(elt2['tokened'])
+                        index = joined_sentence.index(joined_token)
+                        start = joined_sentence[:index].count('###')
+                        elt2['tagged'] = [(w, t) for (w, t)
+                                                 in tagged_sentence[start:]]
+                        elt2['regex'] = pattern[2][1]
+
+                        if pattern[0]:
+                            elt1, elt2 = elt2, elt1
+
+                        # skip short untagged matches (likely noise)
+                        if (not elt1['regex'] and len(elt1['group']) < min_len) or \
+                           (not elt2['regex'] and len(elt2['group']) < min_len):
+                            continue
+
+                        if elt1['regex']:
+                            grammar = r"""user_regex : { %s }""" % elt1['regex']
+                            parser = RegexpParser(grammar)
+                            prod = parser.parse(elt1['tagged']).productions()
+                            if len(prod) == 1: # Nothing has been produced
+                                continue
+                            else:
+                                elt1['group'] = ' '.join([w for (w, _)
+                                                                in prod[1].rhs()])
+                        if elt2['regex']:
+                            grammar = r"""user_regex : { %s }""" % elt2['regex']
+                            parser = RegexpParser(grammar)
+                            prod = parser.parse(elt2['tagged']).productions()
+                            if len(prod) == 1: # Nothing has been produced
+                                continue
+                            else:
+                                elt2['group'] = ' '.join([w for (w, _)
+                                                                in prod[1].rhs()])
+
+                        self._vprint( elt1['group'].encode('utf-8'),
+                                     "###",  elt2['group'].encode('utf-8'),
+                                     "###", sent.encode('utf-8'))
+                        yield (((elt1['group'], elt1['regex']), \
+                                (elt2['group'], elt2['regex'])))
+                        break
+
+    def _save_relations(self, relations, filename, encoding = 'utf-8'):
+        """ Save relations into filename
+        """
+        fobj = open(filename, 'a')
+        for ((elt1, _), (elt2, _)) in relations:
+            fobj.write('%s ; %s\n' % (elt1.encode(encoding),
+                                      elt2.encode(encoding)))
+        fobj.close()
+        self._vprint("Relations saved")
+
+    def _save_patterns(self, filename, encoding = 'utf-8'):
+        """ Save self.patterns into filename
+        """
+        fobj = open(filename, 'a')
+        for (direction, regex, tags) in self.patterns:
+            tags = [t or 'NA' for t in tags]
+            fobj.write('%s ## %s ## %s\n' % (direction,
+                                             regex.encode(encoding),
+                                             '@'.join(tags).encode(encoding)))
+        fobj.close()
+        self._vprint("Patterns saved")
+
+    def load_patterns(self, filename, encoding = 'utf-8'):
+        """ Load pattern from filename
+
+            (should be used only with a file made by _save_patterns() method)
+        """
+        self.patterns = []
+        fobj = open(filename)
+        for line in fobj:
+            direction, regex, raw_tags = line.decode(encoding).split(' ## ')
+            tags = []
+            for t in raw_tags.rstrip('\n').split('@'): #remove the trailing newline
+                if t == 'NA':
+                    tags.append(None)
+                else:
+                    tags.append(t)
+
+            self.patterns.append((int(direction),
+                                  regex,
+                                  tuple(tags)))
+        fobj.close()
+        self._vprint("Patterns loaded")
+
+    def run(self, pattern_file = 'patterns_set',
+            extracted_file = 'relations_extracted'):
+        """ Call this method once a seed has been given to dipre, and enjoy
+
+            This method :
+                0) Calls find_sentences()
+                1) For each item in the result :
+                    0) split it into (sent, seed)
+                    1) append the result of build_occurrences(seed, sent) to a
+                       reslist
+                2) Calls extract_patterns(reslist)
+                3) Ask you to check if the regex patterns are acceptable
+                4) Save regexes into pattern_file
+                5) Call find_new_relations()
+                6) Save results into extracted_file
+                7) Set the results as new seeds
+                8) Ask you if you want to go to step 0 or exit here
+
+        """
+        stop = ''
+        if not self.seeds:
+            print "Hey guy, I'm not a wizard. You have to give me some food "
+            print "before I start. An advice, check my set_seeds() method ;)"
+            raise ValueError('Seeds are missing')
+        while not stop:
+            res = []
+            self._vprint("Building relations")
+            self._vprint(" Looking for sentences")
+            sents_seeds = self.find_sentences()
+            self._vprint(" Building occurrences")
+            for (sent, seed) in sents_seeds:
+                res.append(self.build_occurrences(seed, sent))
+            self._vprint("Building patterns")
+            self.extract_patterns(res)
+            if not self.patterns:
+                self._vprint("No patterns found")
+                return
+            else:
+                print('Please check the patterns before continuing…')
+                #The element at index 1 of each pattern (i.e. the regex) is editable
+                self.patterns = check_list(self.patterns, [1])
+            if not self.patterns:
+                return
+            self._save_patterns(pattern_file)
+            self._vprint("Looking for new relations")
+            relations = [r for r in self.find_new_relations()]
+            if relations:
+                self._vprint(" Storing relations")
+                self._save_relations(relations, extracted_file)
+                newseeds = set([(e1, e2) for e1, e2 in relations])
+                self.set_seeds(newseeds.difference(self.seeds))
+                stop = raw_input('  [Enter to continue] ')
+            else:
+                break
+            if not self.seeds:
+                break
+
+if __name__ == "__main__":
+    filename = u'/home/schabot/dev/data/db/short_abstracts_en_uris_fr.nt'
+    dipre = Dipre(filename, marker = '> "', database_encoding = 'unicode_escape',
+                  verbose = True,
+                  tagger='/home/schabot/dev/tools/french_bigramtagger_LOC.pkl')
+    seeds = set([
+                    ((u'Victor Hugo', '<NAM>+'), (u'les Misérables', None)),
+                ])
+    dipre.set_seeds(seeds)
+    dipre.run()