TextCNN(文本分类)

TextCNN(文本分类)TextCNN网络结构如图所示:利用TextCNN做文本分类基本流程(以句子分类为例):(1)将句子转成词,利用词建立字典(2)词转成向量(word2vec,Glove,bert,nn.embedding)(3)句子补0操作变成等长(4)建TextCNN模型,训练,测试TextCNN按照流程的一个例子。1,预测结果不是很好,句子太少2,没有用到复杂的word…

大家好,又见面了,我是你们的朋友全栈君。

TextCNN网络结构如图所示:

TextCNN(文本分类)

 利用TextCNN做文本分类基本流程(以句子分类为例):

(1)将句子转成词,利用词建立字典

(2)词转成向量(word2vec,Glove,bert,nn.embedding)

(3)句子补0操作变成等长

(4)建TextCNN模型,训练,测试

TextCNN按照流程的一个例子。

1,预测结果不是很好,句子太少 

2,没有用到复杂的word2vec的模型

3,句子太少,没有eval function。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
sentence = [['The movie is great'],['Tom has a headache today'],['I think the apple is bad'],\
            ['You are so beautiful']]
Label = [1,0,0,1]
 
test_sentence = ['The apple is great']
 
class sentence2id:
    def __init__(self,sentence):
        self.sentence = sentence
        self.dic = {}
        self.words = []
        self.words_num = 0
    
    def sen2sen(self,sentence): ##大写转小写
        senten = []
        if type(sentence[0])== list:
            for sen in sentence:
                sen = sen[0].lower()
                senten.append([sen])           
        else:
            senten.append(sentence)
            senten = self.sen2sen(senten)
        return senten
         
            
    def countword(self): ##统计单词个数 
        ############[建库过程不涉及到test模块]#############
        for sen in self.sentence:
            sen = sen[0].split(' ') ##空格分隔
            for word in sen:
                self.words.append(word.lower())
        self.words = list(set(self.words))
        self.words = sorted(self.words)
        self.words_num = len(self.words)
        return self.words,self.words_num
    
    def word2id(self): ### 创建词汇表
        flag = 1
        for word in self.words:
            if flag <= self.words_num:
                self.dic[word] = flag
                flag += 1
        #print(self.dic)
        return self.dic
    
    def sen2id(self,sentence): ###
        sentence = self.sen2sen(sentence)
        sentoid = []
        for sen in sentence:
            senten = []
            for word in sen[0].split():
                senten.append(self.dic[word])
            sentoid.append(senten)
        return sentoid
                
def padded(sentence,pad_token): #token'<pad>'
    max_len = len(sentence[0])
    for i in range(0,len(sentence)-1):
        if max_len < len(sentence[i+1]):
            max_len = len(sentence[i+1])
        i += 1
    for i in range(0,len(sentence)):
        for j in range(0,max_len-len(sentence[i])):
            sentence[i].append(pad_token)
    return sentence
 
 
class ModelEmbeddings(nn.Module):
    def __init__(self,words_num,embed_size,pad_token): 
        super(ModelEmbeddings, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size
        self.Embedding = nn.Embedding(words_num,embed_size,pad_token)
 
class textCNN(nn.Module):
    def __init__(self,words_num,embed_size,class_num,dropout_rate=0.1):
        super(textCNN, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size 
        self.class_num = class_num
        
        self.conv1 = nn.Conv2d(1,3,(2,self.embed_size)) ###in_channels, out_channels, kernel_size
        self.conv2 = nn.Conv2d(1,3,(3,self.embed_size)) 
        self.conv3 = nn.Conv2d(1,3,(4,self.embed_size))
        
        self.max_pool1 = nn.MaxPool1d(5)
        self.max_pool2 = nn.MaxPool1d(4)
        self.max_pool3 = nn.MaxPool1d(3)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(3*3*1,class_num)
        # 3 -> out_channels 3 ->kernel_size 1 ->max_pool
    
    def forward(self,sen_embed): #(batch,max_len,embed_size)
        sen_embed = sen_embed.unsqueeze(1) #(batch,in_channels,max_len,embed_size)
        
        conv1 = F.relu(self.conv1(sen_embed))  # ->(batch_size,out_channels.size,1)
        conv2 = F.relu(self.conv2(sen_embed))
        conv3 = F.relu(self.conv3(sen_embed))
    
        conv1 = torch.squeeze(conv1,dim=3)
        conv2 = torch.squeeze(conv2,dim=3)
        conv3 = torch.squeeze(conv3,dim=3)
        
        x1 = self.max_pool1(conv1)
        x2 = self.max_pool2(conv2)
        x3 = self.max_pool3(conv3)
        
        x = torch.cat((x1,x2),dim=1)
        x = torch.cat((x,x3),dim=1).squeeze(dim=2)
        
        output = self.linear(self.dropout(x))
        
        return output 
 
def train(model,sentence,label):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    steps = 0
    best_acc = 0
    model.train()
    print ("-"*80)
    print('Training....')
    
    for epoch in range(1,2): ##2个epoch
        for step,x in enumerate(torch.split(sentence,1,dim=0)):
            target = torch.zeros(1)
            target[0] = label[step]
            target = torch.tensor(target,dtype=torch.long)
            optimizer.zero_grad()
            output  = model(x)
            loss = criterion(output, target)
            #loss.backward()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            if step % 2 == 0:
                result = torch.max(output,1)[1].view(target.size())
                corrects = (result.data == target.data).sum()
                accuracy = corrects*100.0/1 ####1 is batch size 
                print('Epoch:',epoch,'step:',step,'- loss: %.6f'% loss.data.item(),\
                      'acc: %.4f'%accuracy)
    return model
        
 
if __name__ == '__main__':
    test = sentence2id(sentence)
    test.sen2sen(sentence)
    word,words_num = test.countword()
    test.word2id()
    
    sen_train = test.sen2id(sentence)
    sen_test = test.sen2id(test_sentence)
   
   
    X_train = torch.LongTensor((padded(sen_train,0)))
    X_test = torch.LongTensor((padded(sen_test,0)))
 
    Embedding = ModelEmbeddings(words_num+1,10,0)
    
    X_train_embed = Embedding.Embedding(X_train)
    X_test_embed = Embedding.Embedding(X_test)
    print(X_train_embed.size())
    #print(X_test_embed.size())
    
    ## TextCNN
    textcnn = textCNN(words_num,10,2)
    model = train(textcnn,X_train_embed,Label)  
    print(torch.max(model(X_test_embed),1)[1])

其中建立卷积层时,可以采用nn.ModuleList(),因为用起来不熟练,就直接展开了。等学好了,再补。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/154030.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月28日 下午12:00
下一篇 2022年6月28日 下午12:00


相关推荐

  • JRebel 激活地址

    JRebel 激活地址url 地址 http jrebel qekang com f361d7c3 4364 4070 8eca b3d745 邮箱 自己可用的地址

    2026年3月18日
    2
  • java反射的原理,作用

    什么是反射,反射原理Java反射的原理:java类的执行需要经历以下过程,编译:.java文件编译后生成.class字节码文件加载:类加载器负责根据一个类的全限定名来读取此类的二进制字节流到JVM内部,并存储在运行时内存区的方法区,然后将其转换为一个与目标类型对应的java.lang.Class对象实例连接:细分三步验证:格式(class文件规范)语义(final类是否有子类)…

    2022年4月11日
    46
  • python垃圾回收机制原理

    python垃圾回收机制原理#python垃圾回收机制详解一、概述:  python的GC模块主要运用了“引用计数(referencecounting)”来跟踪和回收垃圾。在引用计数的基础上,还可以通过标记清除(markandsweep)解决容器(这里的容器值指的不是docker,而是数组,字典,元组这样的对象)对象可能产生的循环引用的问题。通过“分代回收(generationcollection)”以空间换取时间来进一步提高垃圾回收的效率。二、垃圾回收三种机制  1、引用计数  在Python中,大多数对象的生命周

    2022年6月24日
    32
  • OpenClaw 常见报错与解决方案

    OpenClaw 常见报错与解决方案

    2026年3月17日
    14
  • SpringBoot之SpringApplication初始化

    SpringBoot之SpringApplication初始化SpringApplication的初始化之前已经分析了引导类上的@SpringBootApplication注解,接下来继续分析main方法,只调用了一句SpringApplication.run(SpringbootApplication.class,args),就启动了web容器,我们看看run方法里面做了什么publicstaticConfigurableApplicationContextrun(Class<?>[]primarySources,String[]ar

    2025年8月26日
    9
  • c语言里面的枚举有啥作用,C语言枚举enum

    c语言里面的枚举有啥作用,C语言枚举enumC 语言枚举 enum 教程枚举是枚举的作用就是给我们常用的 C 语言枚举 enum 定义详解语法 enum 枚举名 枚举元素 1 枚举元素 2 枚举元素 3 参数参数描述 enum 定义枚举类型所使用的关键字 枚举名枚举的变量名 枚举元素 1 枚举元素 2 枚举元素 3 枚举的元素列表 说明我们使用 enum 关键字 定义了一个枚举变量 该枚举变量有三个元素 C 语言枚举 enum 变量定义详解语法 enum 枚举名 varn

    2025年7月12日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号