1、
清除字符串中除字母、数字和空格以外的字符
re.sub(r'[^a-zA-Z0-9 ]', '', s)
2、将多个list合并为一个list
sum(list, [])
3、set.union(*map(set, candidates_1))
set.union(A,B)返回两个集合的并集
*map(set, candidates)将candidates中的每个值映射为set,并传递每个参数
4、presumm模型如何将json_data转化为bert_data。
import os
import re
import json

def _get_ngrams(n, text):
    """Calcualtes n-grams.

    Args:
      n: which n-grams to calculate
      text: An array of tokens

    Returns:
      A set of n-grams
    """
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set


def _get_word_ngrams(n, sentences):
    """Calculates word n-grams over the concatenation of multiple sentences.

    Args:
      n: which n-grams to calculate (must be > 0)
      sentences: non-empty list of sentences, each a list of tokens

    Returns:
      A set of word n-grams (tuples of ``n`` tokens).

    Raises:
      AssertionError: if ``sentences`` is empty or ``n`` is not positive.
    """
    assert len(sentences) > 0
    assert n > 0

    # Flatten into one token list. A nested comprehension is linear,
    # unlike sum(sentences, []) which re-copies the accumulator each step.
    words = [w for sent in sentences for w in sent]
    return _get_ngrams(n, words)

def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Computes smoothed ROUGE precision, recall and F1 between n-gram sets.

    Args:
      evaluated_ngrams: set of n-grams from the candidate summary
      reference_ngrams: set of n-grams from the reference summary

    Returns:
      dict with keys "f" (F1), "p" (precision), "r" (recall)
    """
    n_eval = len(evaluated_ngrams)
    n_ref = len(reference_ngrams)
    # Size of the intersection of the two n-gram sets.
    n_common = len(evaluated_ngrams & reference_ngrams)

    precision = n_common / n_eval if n_eval else 0.0
    recall = n_common / n_ref if n_ref else 0.0

    # The 1e-8 term keeps the denominator non-zero when both p and r are 0.
    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}

def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    """Greedily selects up to ``summary_size`` document sentences that
    maximize ROUGE-1 + ROUGE-2 F1 against the abstract (oracle labels).

    Args:
      doc_sent_list: list of document sentences, each a list of tokens
      abstract_sent_list: reference abstract in the same nested-list format
      summary_size: maximum number of sentence indices to select

    Returns:
      Sorted list of selected sentence indices. May be shorter than
      ``summary_size`` if no remaining sentence improves the score.
    """
    def _rouge_clean(s):
        # Keep only letters, digits and spaces.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    # Flatten the abstract into one token list, then clean and re-tokenize.
    abstract = sum(abstract_sent_list, [])
    abstract = _rouge_clean(' '.join(abstract)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]

    # Pre-compute per-sentence n-grams and the reference n-grams once.
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    selected = []
    for _ in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):  # try adding each unselected sentence
            if i in selected:
                continue
            c = selected + [i]

            # Union of n-grams over the candidate sentence set.
            candidates_1 = set.union(*map(set, [evaluated_1grams[idx] for idx in c]))
            candidates_2 = set.union(*map(set, [evaluated_2grams[idx] for idx in c]))

            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if cur_id == -1:
            # No sentence improves the score: stop early. Fix: sort here too,
            # so both exits return indices in the same (ascending) order.
            return sorted(selected)
        selected.append(cur_id)
        max_rouge = cur_max_rouge

    return sorted(selected)


def format_to_bert(json_file, max_src_nsents=100, save_file=None):
    """Reads one json_data shard and prints oracle sentence labels
    (the greedy-selection step of PreSumm's json -> bert_data conversion).

    Args:
      json_file: path to a json shard holding a list of {'src': ..., 'tgt': ...}
      max_src_nsents: cap on the number of source sentences considered
      save_file: unused here; kept for interface compatibility with PreSumm
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original json.load(open(json_file)) leaked the handle.
    with open(json_file) as f:
        jobs = json.load(f)
    for d in jobs:
        source, tgt = d['src'], d['tgt']
        # Oracle: the 3 sentences with the highest greedy ROUGE score.
        sent_labels = greedy_selection(source[:max_src_nsents], tgt, 3)
        print(sent_labels)

if __name__ == '__main__':
    # NOTE(review): corpus_type is assigned but unused in this snippet.
    corpus_type = "train"
    # Sample CNN/DailyMail shard produced by PreSumm's json preprocessing.
    json_file = "json_data/cnndm_sample.train.0.json"

    format_to_bert(json_file)


5、
"""
RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
解决方案:
将mask_src = 1 - (src == 0)修改为mask_src = ~(src == 0)
原因:
不允许对bool变量进行“-”操作,如果需要对bool变量进行反转,则使用“~”操作
"""
6、# 使用glob找到所有匹配的文件路径列表
# [0-9]表示匹配指定范围内的字符,即0~9
# *表示匹配0个或多个字符
pts = sorted(glob.glob(args.bert_data_path + '.' + corpus_type + '.[0-9]*.pt'))
7、# 在PreSumm代码中,训练集存在多个pt文件,读入的方式利用到了生成器,所以不太好理解,特此记录一下
import gc
class Dataloader(object):
    """Yields batches drawn from a sequence of datasets, holding only one
    dataset in memory at a time (PreSumm keeps one dataset per .pt file).
    """

    def __init__(self, datasets):
        # datasets: iterable/generator of datasets, e.g. [{...}, {...}, ...]
        self.datasets = datasets
        self.cur_iter = self._next_dataset_iterator(datasets)
        assert self.cur_iter is not None

    def __iter__(self):
        dataset_iter = (d for d in self.datasets)
        # Bug fix: the debug `print(list(dataset_iter))` here exhausted the
        # generator, so every dataset after the first one was silently lost.
        while self.cur_iter is not None:
            for batch in self.cur_iter:
                yield batch
            self.cur_iter = self._next_dataset_iterator(dataset_iter)

    def _next_dataset_iterator(self, dataset_iter):
        try:
            # Drop the current dataset to release memory before loading the next.
            if hasattr(self, "cur_dataset"):
                self.cur_dataset = None
                gc.collect()
                del self.cur_dataset
                gc.collect()
            # next() consumes the item: the generator cannot replay it later.
            self.cur_dataset = next(dataset_iter)

        except StopIteration:
            # All datasets consumed.
            return None

        return self.cur_dataset

def get_data(input):
    """Yields each element of *input* in order (thin generator wrapper)."""
    # `yield from` delegates iteration to the underlying iterable.
    yield from input

if __name__ == "__main__":
    # Three fake "datasets" to demonstrate the generator-based Dataloader.
    input = [{"name": "tyler1", "age": 18, "home": "china1"},
             {"name": "tyler2", "age": 19, "home": "china2"},
             {"name": "tyler3", "age": 20, "home": "china3"}]
    datasets = get_data(input)
    data_loader = Dataloader(datasets)
    # enumerate's index was unused; iterate the batches directly.
    for batch in data_loader:
        print(batch)
8、# map_location=lambda storage, loc: storage
# 当预训练的参数为GPU,模型要被加载到CPU时,调用上述命令
checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
9、抽取式模块-整体流程总结
"""
(1)将json数据转换为bert输入格式,其中使用了greedy_selection函数处理数据集;
(2)将input输入bert中,利用各句子CLS对应的token embedding进行训练;
(3)训练和验证阶段使用loss保存模型;
(4)使用pyrouge计算得分。
"""