TextCNN(文本分类)

TextCNN(文本分类)TextCNN网络结构如图所示:利用TextCNN做文本分类基本流程(以句子分类为例):(1)将句子转成词,利用词建立字典(2)词转成向量(word2vec,Glove,bert,nn.embedding)(3)句子补0操作变成等长(4)建TextCNN模型,训练,测试TextCNN按照流程的一个例子。1,预测结果不是很好,句子太少2,没有用到复杂的word…

大家好,又见面了,我是你们的朋友全栈君。

TextCNN网络结构如图所示:

TextCNN(文本分类)

 利用TextCNN做文本分类基本流程(以句子分类为例):

(1)将句子转成词,利用词建立字典

(2)词转成向量(word2vec,Glove,bert,nn.embedding)

(3)句子补0操作变成等长

(4)建TextCNN模型,训练,测试

TextCNN按照流程的一个例子。

1,预测结果不是很好,句子太少 

2,没有用到复杂的word2vec的模型

3,句子太少,没有eval function。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
sentence = [['The movie is great'],['Tom has a headache today'],['I think the apple is bad'],\
            ['You are so beautiful']]
Label = [1,0,0,1]
 
test_sentence = ['The apple is great']
 
class sentence2id:
    def __init__(self,sentence):
        self.sentence = sentence
        self.dic = {}
        self.words = []
        self.words_num = 0
    
    def sen2sen(self,sentence): ##大写转小写
        senten = []
        if type(sentence[0])== list:
            for sen in sentence:
                sen = sen[0].lower()
                senten.append([sen])           
        else:
            senten.append(sentence)
            senten = self.sen2sen(senten)
        return senten
         
            
    def countword(self): ##统计单词个数 
        ############[建库过程不涉及到test模块]#############
        for sen in self.sentence:
            sen = sen[0].split(' ') ##空格分隔
            for word in sen:
                self.words.append(word.lower())
        self.words = list(set(self.words))
        self.words = sorted(self.words)
        self.words_num = len(self.words)
        return self.words,self.words_num
    
    def word2id(self): ### 创建词汇表
        flag = 1
        for word in self.words:
            if flag <= self.words_num:
                self.dic[word] = flag
                flag += 1
        #print(self.dic)
        return self.dic
    
    def sen2id(self,sentence): ###
        sentence = self.sen2sen(sentence)
        sentoid = []
        for sen in sentence:
            senten = []
            for word in sen[0].split():
                senten.append(self.dic[word])
            sentoid.append(senten)
        return sentoid
                
def padded(sentence,pad_token): #token'<pad>'
    max_len = len(sentence[0])
    for i in range(0,len(sentence)-1):
        if max_len < len(sentence[i+1]):
            max_len = len(sentence[i+1])
        i += 1
    for i in range(0,len(sentence)):
        for j in range(0,max_len-len(sentence[i])):
            sentence[i].append(pad_token)
    return sentence
 
 
class ModelEmbeddings(nn.Module):
    def __init__(self,words_num,embed_size,pad_token): 
        super(ModelEmbeddings, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size
        self.Embedding = nn.Embedding(words_num,embed_size,pad_token)
 
class textCNN(nn.Module):
    def __init__(self,words_num,embed_size,class_num,dropout_rate=0.1):
        super(textCNN, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size 
        self.class_num = class_num
        
        self.conv1 = nn.Conv2d(1,3,(2,self.embed_size)) ###in_channels, out_channels, kernel_size
        self.conv2 = nn.Conv2d(1,3,(3,self.embed_size)) 
        self.conv3 = nn.Conv2d(1,3,(4,self.embed_size))
        
        self.max_pool1 = nn.MaxPool1d(5)
        self.max_pool2 = nn.MaxPool1d(4)
        self.max_pool3 = nn.MaxPool1d(3)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(3*3*1,class_num)
        # 3 -> out_channels 3 ->kernel_size 1 ->max_pool
    
    def forward(self,sen_embed): #(batch,max_len,embed_size)
        sen_embed = sen_embed.unsqueeze(1) #(batch,in_channels,max_len,embed_size)
        
        conv1 = F.relu(self.conv1(sen_embed))  # ->(batch_size,out_channels.size,1)
        conv2 = F.relu(self.conv2(sen_embed))
        conv3 = F.relu(self.conv3(sen_embed))
    
        conv1 = torch.squeeze(conv1,dim=3)
        conv2 = torch.squeeze(conv2,dim=3)
        conv3 = torch.squeeze(conv3,dim=3)
        
        x1 = self.max_pool1(conv1)
        x2 = self.max_pool2(conv2)
        x3 = self.max_pool3(conv3)
        
        x = torch.cat((x1,x2),dim=1)
        x = torch.cat((x,x3),dim=1).squeeze(dim=2)
        
        output = self.linear(self.dropout(x))
        
        return output 
 
def train(model,sentence,label):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    steps = 0
    best_acc = 0
    model.train()
    print ("-"*80)
    print('Training....')
    
    for epoch in range(1,2): ##2个epoch
        for step,x in enumerate(torch.split(sentence,1,dim=0)):
            target = torch.zeros(1)
            target[0] = label[step]
            target = torch.tensor(target,dtype=torch.long)
            optimizer.zero_grad()
            output  = model(x)
            loss = criterion(output, target)
            #loss.backward()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            if step % 2 == 0:
                result = torch.max(output,1)[1].view(target.size())
                corrects = (result.data == target.data).sum()
                accuracy = corrects*100.0/1 ####1 is batch size 
                print('Epoch:',epoch,'step:',step,'- loss: %.6f'% loss.data.item(),\
                      'acc: %.4f'%accuracy)
    return model
        
 
if __name__ == '__main__':
    test = sentence2id(sentence)
    test.sen2sen(sentence)
    word,words_num = test.countword()
    test.word2id()
    
    sen_train = test.sen2id(sentence)
    sen_test = test.sen2id(test_sentence)
   
   
    X_train = torch.LongTensor((padded(sen_train,0)))
    X_test = torch.LongTensor((padded(sen_test,0)))
 
    Embedding = ModelEmbeddings(words_num+1,10,0)
    
    X_train_embed = Embedding.Embedding(X_train)
    X_test_embed = Embedding.Embedding(X_test)
    print(X_train_embed.size())
    #print(X_test_embed.size())
    
    ## TextCNN
    textcnn = textCNN(words_num,10,2)
    model = train(textcnn,X_train_embed,Label)  
    print(torch.max(model(X_test_embed),1)[1])

其中建立卷积层时,可以采用nn.ModuleList(),因为用起来不熟练,就直接展开了。等学好了,再补。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/154030.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月28日 下午12:00
下一篇 2022年6月28日 下午12:00


相关推荐

  • 简单完整讲述Servlet生命周期

    简单完整讲述Servlet生命周期servlet 生命周期过程 1 加载 web xml 文件 当前只去解析 xml 文件 知道 servlet 的存在 此时还没有去创建 servlet 声明 servlet servlet servlet 的别名 servlet name first servlet name servlet class com etime servlet FirstServlet servlet class servlet

    2026年3月19日
    2
  • UML活动图

    UML活动图面向对象的软件开发方法的第一步:业务建模<–使用活动图转载:https://www.cnblogs.com/xiaolongbao-lzh/p/4591953.html活动图概述•活动图和交互图是UML中对系统动态方面建模的两种主要形式•交互图强调的是对象到对象的控制流,而活动图则强调的是从活动到活动的控制流•活动图是一种表述过程基理、业务过程以及工作流的技术。它可以用…

    2022年6月14日
    26
  • k8s(十)基本存储[通俗易懂]

    k8s(十)基本存储[通俗易懂]文章目录概述EmptyDirHostPathNFSk8s的数据存储概述在前面已经提到,容器的生命周期可能很短,会被频繁的创建和销毁。那么容器在销毁的时候,保存在容器中的数据也会被清除。这种结果对用户来说,在某些情况下是不乐意看到的。为了持久化保存容器中的数据,kubernetes引入了Volume的概念。Volume是Pod中能够被多个容器访问的共享目录,它被定义在Pod上,然后被一个Pod里面的多个容器挂载到具体的文件目录下,kubernetes通过Volume实现同一个Pod中不同容器之间的数据

    2022年8月9日
    4
  • 常量指针和指针常量的区别详解

    常量指针和指针常量的区别详解常量指针和指针常量的区别详解

    2026年3月26日
    2
  • velocity中的注释种类

    velocity中的注释种类各种编程语言都有相对应的注释 而 velocity 作为一种模板引擎也不例外 大体上 velocity 的主食类型分为如下几类 nbsp 单行注释 nbsp nbsp Thisisasingl nbsp 多行注释 nbsp nbsp nbsp nbsp nbsp Thusbeginsam linecomment Onlinevisito t nbsp nbsp nbsp seethistex

    2026年3月26日
    3
  • LDR命令

    LDR命令  LDR指令用于从内存中将一个32位的字读取到指令中的目标寄存器中,如果目标寄存器为PC,则可以实现“长”跳转。主要有一下3种方式使用:ldrr0,_startldrr0,=_startldrpc,_start  逐条分析:一、ldrr0,_start  从内存地址_start的地方,把其对应的命令执行对应的“执行码”读入到r0中。二、ldrr0,=_start  …

    2022年6月28日
    85

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号