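# 인공지능 공부/NLP 연구 (AI study / NLP research)
#
# (NLP 연구) The Long-Document Transformer 03.31 (LSH)
# (NLP research notes on "Longformer: The Long-Document Transformer", 03.31, LSH)
#
# 앨런튜링_ 2022. 4. 1. 15:29 (author and posting time)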
import glob
import os
import io
import string
import re
import random
import spacy
import torchtext
from torchtext.vocab import Vectors
import re
import numpy as np
import itertools
from random import shuffle
def create_hash_func(size: int):
    """Return one random MinHash function over a vocabulary of ``size`` items.

    The result is a shuffled list of the hash values ``1..size``; position
    ``p`` holds the hash value assigned to vocabulary column ``p``.

    Args:
        size: number of vocabulary entries (columns) to hash.

    Returns:
        A random permutation of ``range(1, size + 1)`` as a list.
    """
    # Use the ``size`` argument rather than the global ``vocab``, so the
    # function works without module state. At the existing call site
    # size == len(vocab), so the result is unchanged there.
    hash_ex = list(range(1, size + 1))
    shuffle(hash_ex)
    return hash_ex

def build_minhash_func(vocab_size: int, nbits: int):
    """Create ``nbits`` independent hash functions via ``create_hash_func``.

    Args:
        vocab_size: forwarded unchanged to every ``create_hash_func`` call.
        nbits: how many hash functions to create (signature length).

    Returns:
        A list of the created hash functions, in creation order.
    """
    return [create_hash_func(vocab_size) for _ in range(nbits)]

def create_hash(vector: list, hash_funcs=None):
    """Compute a MinHash signature for every 0/1 row of ``vector``.

    Each hash function is a list in which position ``p`` holds the hash value
    of vocabulary column ``p``. For each row and each hash function, the
    signature records the smallest hash value among the columns where the
    row equals 1.

    Args:
        vector: list of 0/1 rows, e.g. the output of ``onehot_shingle_func``.
        hash_funcs: hash functions to apply. Defaults to the module-level
            ``minhash_func``, looked up at call time.

    Returns:
        One list of hash values per row. A row with no 1 entries contributes
        nothing for a hash function, so signatures can be shorter than the
        number of hash functions (an all-zero row yields ``[]``).
    """
    funcs = minhash_func if hash_funcs is None else hash_funcs
    # Precompute, per function, its columns in increasing hash-value order.
    # This replaces a ``func.index(i)`` scan for every candidate value of
    # every row, which is O(V) per lookup and O(V^2) per row and function.
    plans = [
        (func, sorted(range(len(func)), key=func.__getitem__)) for func in funcs
    ]
    signature_list = []
    for row in vector:
        signature = []
        for func, columns in plans:
            for col in columns:
                if row[col] == 1:
                    # First hit in ascending hash order = minimum hash value.
                    signature.append(func[col])
                    break
        signature_list.append(signature)
    return signature_list

def onehot_shingle_func(vocab : list, text: list) :
    """Encode each sentence as a 0/1 presence vector over ``vocab``.

    Args:
        vocab: vocabulary entries; column ``k`` of every row refers to
            ``vocab[k]``.
        text: sentences to encode.

    Returns:
        One list per sentence, containing 1 where ``vocab[k] in sentence``
        and 0 otherwise. For string sentences, ``in`` is a substring test.
    """
    return [[1 if token in sentence else 0 for token in vocab] for sentence in text]

def list_overlap_del(input_list):
    """Collapse runs of consecutive equal items, keeping the first of each run.

    Only adjacent duplicates are removed; a value that reappears later is
    kept, e.g. ``[1, 1, 2, 1] -> [1, 2, 1]``.

    Args:
        input_list: any iterable of items comparable with ``==``.

    Returns:
        A new list with consecutive repeats removed (``[]`` for empty input).
    """
    return [item for item, _ in itertools.groupby(input_list)]

def max_index_func(total_sig, similarity=None):
    """Greedily pair each signature with its most similar not-yet-chosen one.

    Builds the full pairwise similarity matrix between ``set(sig)`` for every
    signature, with each score rounded to 3 decimals. It then walks the rows
    in order: for row ``i``, the columns already chosen by rows ``0..i-1``
    are zeroed and ``np.argmax`` picks the best remaining column.

    Args:
        total_sig: list of signatures (iterables of hash values), e.g. the
            output of ``create_hash``.
        similarity: optional callable ``(set, set) -> float``. Defaults to
            this module's ``jaccard``, looked up at call time.

    Returns:
        A list with one numpy integer index per signature. When every
        remaining score in a row is 0, ``np.argmax`` returns 0, so index 0 can
        be chosen more than once.
    """
    sim = jaccard if similarity is None else similarity
    n = len(total_sig)
    sig_sets = [set(sig) for sig in total_sig]

    scores = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # Scores equal after rounding tie; argmax then takes the lowest
            # column.
            scores[i, j] = round(sim(sig_sets[i], sig_sets[j]), 3)

    max_index = []
    for i in range(n):
        if max_index:
            scores[i, max_index] = 0  # exclude partners already taken
        max_index.append(np.argmax(scores[i]))
    return max_index

def sorted_func(max_index, text):
    """Interleave each sentence position with its chosen partner.

    Produces ``[0, max_index[0], 1, max_index[1], ...]`` covering every
    position of ``text``; only the length of ``text`` is used. Raises
    IndexError if ``max_index`` is shorter than ``text``.
    """
    interleaved = []
    for position, _ in enumerate(text):
        interleaved.extend((position, max_index[position]))
    return interleaved
        
def list_overlap_del(input_list):
    """Collapse runs of consecutive equal items, keeping the first of each run.

    Only adjacent duplicates are removed; a value that reappears later is
    kept, e.g. ``[1, 1, 2, 1] -> [1, 2, 1]``.

    Args:
        input_list: any iterable of items comparable with ``==``.

    Returns:
        A new list with consecutive repeats removed (``[]`` for empty input).
    """
    return [item for item, _ in itertools.groupby(input_list)]

def jaccard(x,y):
    """Return the Jaccard similarity ``len(x & y) / len(x | y)`` of ``x`` and ``y``.

    ``x`` must be a set; ``y`` may be any iterable accepted by
    ``set.intersection``/``set.union``.

    Returns 0 when the inputs share nothing, when both are empty, and also
    when the similarity is exactly 1 (identical sets). The last rule keeps a
    signature from matching itself: ``max_index_func`` scores every signature
    against itself on the diagonal and then takes a per-row argmax. As a side
    effect, two distinct but identical signatures also score 0.
    """
    common = len(x.intersection(y))
    total = len(x.union(y))
    # Compare the intersection's size to 0; comparing the set itself to 0 is
    # never true.
    if common == 0 or total == 0 or common == total:
        return 0
    return common / total
def get_split(text: str):
    """Clean one raw review and split it into sentence fragments on '.'.

    Tabs become spaces, then backslashes, '--' and '<br />' are removed, in
    that order, each step applied to the previous step's result. Empty
    fragments are kept, e.g. a trailing '' when the text ends with '.'.
    """
    cleanups = (('\t', " "), ('\\', ""), ('--', ""), ('<br />', ''))
    for old, new in cleanups:
        text = text.replace(old, new)
    return text.split('.')
def vocab_func(text: list, k : int):
    """Return ``shingle(sentence, k)`` for every sentence in ``text``, in order.

    NOTE(review): ``shingle`` is not defined anywhere in this file; it must
    be provided elsewhere before this is called.
    """
    return [shingle(sentence, k) for sentence in text]
def create_hash_func(size: int):
    """Return one random MinHash function over a vocabulary of ``size`` items.

    The result is a shuffled list of the hash values ``1..size``; position
    ``p`` holds the hash value assigned to vocabulary column ``p``.

    Args:
        size: number of vocabulary entries (columns) to hash.

    Returns:
        A random permutation of ``range(1, size + 1)`` as a list.
    """
    # Use the ``size`` argument rather than the global ``vocab``, so the
    # function works without module state. At the existing call site
    # size == len(vocab), so the result is unchanged there.
    hash_ex = list(range(1, size + 1))
    shuffle(hash_ex)
    return hash_ex

def build_minhash_func(vocab_size: int, nbits: int):
    """Create ``nbits`` independent hash functions via ``create_hash_func``.

    Args:
        vocab_size: forwarded unchanged to every ``create_hash_func`` call.
        nbits: how many hash functions to create (signature length).

    Returns:
        A list of the created hash functions, in creation order.
    """
    funcs = []
    remaining = nbits
    while remaining > 0:
        funcs.append(create_hash_func(vocab_size))
        remaining -= 1
    return funcs

def create_hash(vector: list, hash_funcs=None):
    """Compute a MinHash signature for every 0/1 row of ``vector``.

    Each hash function is a list in which position ``p`` holds the hash value
    of vocabulary column ``p``. For each row and each hash function, the
    signature records the smallest hash value among the columns where the
    row equals 1.

    Args:
        vector: list of 0/1 rows, e.g. the output of ``onehot_shingle_func``.
        hash_funcs: hash functions to apply. Defaults to the module-level
            ``minhash_func``, looked up at call time.

    Returns:
        One list of hash values per row. A row with no 1 entries contributes
        nothing for a hash function, so signatures can be shorter than the
        number of hash functions (an all-zero row yields ``[]``).
    """
    funcs = minhash_func if hash_funcs is None else hash_funcs
    # Precompute, per function, its columns in increasing hash-value order.
    # This replaces a ``func.index(i)`` scan for every candidate value of
    # every row, which is O(V) per lookup and O(V^2) per row and function.
    plans = [
        (func, sorted(range(len(func)), key=func.__getitem__)) for func in funcs
    ]
    signature_list = []
    for row in vector:
        signature = []
        for func, columns in plans:
            for col in columns:
                if row[col] == 1:
                    # First hit in ascending hash order = minimum hash value.
                    signature.append(func[col])
                    break
        signature_list.append(signature)
    return signature_list

def onehot_shingle_func(vocab : list, text: list) :
    """Encode each sentence as a 0/1 presence vector over ``vocab``.

    Args:
        vocab: vocabulary entries; column ``k`` of every row refers to
            ``vocab[k]``.
        text: sentences to encode.

    Returns:
        One list per sentence, containing 1 where ``vocab[k] in sentence``
        and 0 otherwise. For string sentences, ``in`` is a substring test.
    """
    return [[1 if token in sentence else 0 for token in vocab] for sentence in text]

def list_overlap_del(input_list):
    """Collapse runs of consecutive equal items, keeping the first of each run.

    Only adjacent duplicates are removed; a value that reappears later is
    kept, e.g. ``[1, 1, 2, 1] -> [1, 2, 1]``.

    Args:
        input_list: any iterable of items comparable with ``==``.

    Returns:
        A new list with consecutive repeats removed (``[]`` for empty input).
    """
    return [item for item, _ in itertools.groupby(input_list)]

def max_index_func(total_sig, similarity=None):
    """Greedily pair each signature with its most similar not-yet-chosen one.

    Builds the full pairwise similarity matrix between ``set(sig)`` for every
    signature, with each score rounded to 3 decimals. It then walks the rows
    in order: for row ``i``, the columns already chosen by rows ``0..i-1``
    are zeroed and ``np.argmax`` picks the best remaining column.

    Args:
        total_sig: list of signatures (iterables of hash values), e.g. the
            output of ``create_hash``.
        similarity: optional callable ``(set, set) -> float``. Defaults to
            this module's ``jaccard``, looked up at call time.

    Returns:
        A list with one numpy integer index per signature. When every
        remaining score in a row is 0, ``np.argmax`` returns 0, so index 0 can
        be chosen more than once.
    """
    sim = jaccard if similarity is None else similarity
    n = len(total_sig)
    sig_sets = [set(sig) for sig in total_sig]

    scores = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # Scores equal after rounding tie; argmax then takes the lowest
            # column.
            scores[i, j] = round(sim(sig_sets[i], sig_sets[j]), 3)

    max_index = []
    for i in range(n):
        if max_index:
            scores[i, max_index] = 0  # exclude partners already taken
        max_index.append(np.argmax(scores[i]))
    return max_index

def sorted_func(max_index, text):
    """Flatten (position, partner) pairs into one list.

    Yields ``[0, max_index[0], 1, max_index[1], ...]`` for each position of
    ``text``; ``text`` contributes only its length. Raises IndexError if
    ``max_index`` has fewer entries than ``text``.
    """
    return [
        value
        for position in range(len(text))
        for value in (position, max_index[position])
    ]
        
def list_overlap_del(input_list):
    """Collapse runs of consecutive equal items, keeping the first of each run.

    Only adjacent duplicates are removed; a value that reappears later is
    kept, e.g. ``[1, 1, 2, 1] -> [1, 2, 1]``.

    Args:
        input_list: any iterable of items comparable with ``==``.

    Returns:
        A new list with consecutive repeats removed (``[]`` for empty input).
    """
    return [item for item, _ in itertools.groupby(input_list)]

def jaccard(x,y):
    """Return the Jaccard similarity ``len(x & y) / len(x | y)`` of ``x`` and ``y``.

    ``x`` must be a set; ``y`` may be any iterable accepted by
    ``set.intersection``/``set.union``.

    Returns 0 when the inputs share nothing, when both are empty, and also
    when the similarity is exactly 1 (identical sets). The last rule keeps a
    signature from matching itself: ``max_index_func`` scores every signature
    against itself on the diagonal and then takes a per-row argmax. As a side
    effect, two distinct but identical signatures also score 0.
    """
    common = len(x.intersection(y))
    total = len(x.union(y))
    # Compare the intersection's size to 0; comparing the set itself to 0 is
    # never true.
    if common == 0 or total == 0 or common == total:
        return 0
    return common / total
# Build ./data/IMDb_train.tsv from the aclImdb training reviews.
#
# Each review is split into '.'-separated fragments and one-hot encoded over
# the review's 2-shingles. Each fragment gets a 20-function MinHash
# signature, and the fragments are re-emitted in "fragment, most similar
# partner, ..." order (max_index_func -> sorted_func -> list_overlap_del).
# Output row format: "<reordered text>\t<label>\t\n", with 1 = pos, 0 = neg.
#
# NOTE(review): ``shingle`` is not defined in this file; it must be provided
# elsewhere before this script runs.
# ``vocab`` and ``minhash_func`` are deliberately assigned at module level,
# because create_hash_func / create_hash can read them as globals.
with open('./data/IMDb_train.tsv', 'w', encoding="utf-8") as f:
    for name, label, path in (
        ("pos", "1", './data/aclImdb/train/pos/'),
        ("neg", "0", './data/aclImdb/train/neg/'),
    ):
        for fname in glob.glob(os.path.join(path, '*.txt')):
            with io.open(fname, 'r', encoding="utf-8") as ff:
                # Only the first line is read. Drop a trailing newline, if
                # any, so it cannot break the TSV row.
                text = ff.readline().rstrip("\r\n")

            sentences = get_split(text)
            vocab = list(shingle(sentences, 2))
            onehot_shingle = onehot_shingle_func(vocab, sentences)
            minhash_func = build_minhash_func(len(vocab), 20)
            total_sig = create_hash(onehot_shingle)
            max_index_list = max_index_func(total_sig)
            max_sorted = sorted_func(max_index_list, sentences)
            max_remove = list_overlap_del(max_sorted)

            # Join identically for both labels, so whitespace at the edges of
            # the text cannot correlate with the label.
            reordered = " ".join(sentences[i] for i in max_remove)
            print(name + " : " + reordered)
            f.write(reordered + '\t' + label + '\t' + '\n')