TextCNN(文本分类)

TextCNN(文本分类)TextCNN网络结构如图所示:利用TextCNN做文本分类基本流程(以句子分类为例):(1)将句子转成词,利用词建立字典(2)词转成向量(word2vec,Glove,bert,nn.embedding)(3)句子补0操作变成等长(4)建TextCNN模型,训练,测试TextCNN按照流程的一个例子。1,预测结果不是很好,句子太少2,没有用到复杂的word…

大家好,又见面了,我是你们的朋友全栈君。

TextCNN网络结构如图所示:

TextCNN(文本分类)

 利用TextCNN做文本分类基本流程(以句子分类为例):

(1)将句子转成词,利用词建立字典

(2)词转成向量(word2vec,Glove,bert,nn.embedding)

(3)句子补0操作变成等长

(4)建TextCNN模型,训练,测试

TextCNN按照流程的一个例子。

1,预测结果不是很好,句子太少 

2,没有用到复杂的word2vec的模型

3,句子太少,没有eval function。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
sentence = [['The movie is great'],['Tom has a headache today'],['I think the apple is bad'],\
            ['You are so beautiful']]
Label = [1,0,0,1]
 
test_sentence = ['The apple is great']
 
class sentence2id:
    def __init__(self,sentence):
        self.sentence = sentence
        self.dic = {}
        self.words = []
        self.words_num = 0
    
    def sen2sen(self,sentence): ##大写转小写
        senten = []
        if type(sentence[0])== list:
            for sen in sentence:
                sen = sen[0].lower()
                senten.append([sen])           
        else:
            senten.append(sentence)
            senten = self.sen2sen(senten)
        return senten
         
            
    def countword(self): ##统计单词个数 
        ############[建库过程不涉及到test模块]#############
        for sen in self.sentence:
            sen = sen[0].split(' ') ##空格分隔
            for word in sen:
                self.words.append(word.lower())
        self.words = list(set(self.words))
        self.words = sorted(self.words)
        self.words_num = len(self.words)
        return self.words,self.words_num
    
    def word2id(self): ### 创建词汇表
        flag = 1
        for word in self.words:
            if flag <= self.words_num:
                self.dic[word] = flag
                flag += 1
        #print(self.dic)
        return self.dic
    
    def sen2id(self,sentence): ###
        sentence = self.sen2sen(sentence)
        sentoid = []
        for sen in sentence:
            senten = []
            for word in sen[0].split():
                senten.append(self.dic[word])
            sentoid.append(senten)
        return sentoid
                
def padded(sentence,pad_token): #token'<pad>'
    max_len = len(sentence[0])
    for i in range(0,len(sentence)-1):
        if max_len < len(sentence[i+1]):
            max_len = len(sentence[i+1])
        i += 1
    for i in range(0,len(sentence)):
        for j in range(0,max_len-len(sentence[i])):
            sentence[i].append(pad_token)
    return sentence
 
 
class ModelEmbeddings(nn.Module):
    def __init__(self,words_num,embed_size,pad_token): 
        super(ModelEmbeddings, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size
        self.Embedding = nn.Embedding(words_num,embed_size,pad_token)
 
class textCNN(nn.Module):
    def __init__(self,words_num,embed_size,class_num,dropout_rate=0.1):
        super(textCNN, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size 
        self.class_num = class_num
        
        self.conv1 = nn.Conv2d(1,3,(2,self.embed_size)) ###in_channels, out_channels, kernel_size
        self.conv2 = nn.Conv2d(1,3,(3,self.embed_size)) 
        self.conv3 = nn.Conv2d(1,3,(4,self.embed_size))
        
        self.max_pool1 = nn.MaxPool1d(5)
        self.max_pool2 = nn.MaxPool1d(4)
        self.max_pool3 = nn.MaxPool1d(3)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(3*3*1,class_num)
        # 3 -> out_channels 3 ->kernel_size 1 ->max_pool
    
    def forward(self,sen_embed): #(batch,max_len,embed_size)
        sen_embed = sen_embed.unsqueeze(1) #(batch,in_channels,max_len,embed_size)
        
        conv1 = F.relu(self.conv1(sen_embed))  # ->(batch_size,out_channels.size,1)
        conv2 = F.relu(self.conv2(sen_embed))
        conv3 = F.relu(self.conv3(sen_embed))
    
        conv1 = torch.squeeze(conv1,dim=3)
        conv2 = torch.squeeze(conv2,dim=3)
        conv3 = torch.squeeze(conv3,dim=3)
        
        x1 = self.max_pool1(conv1)
        x2 = self.max_pool2(conv2)
        x3 = self.max_pool3(conv3)
        
        x = torch.cat((x1,x2),dim=1)
        x = torch.cat((x,x3),dim=1).squeeze(dim=2)
        
        output = self.linear(self.dropout(x))
        
        return output 
 
def train(model,sentence,label):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    steps = 0
    best_acc = 0
    model.train()
    print ("-"*80)
    print('Training....')
    
    for epoch in range(1,2): ##2个epoch
        for step,x in enumerate(torch.split(sentence,1,dim=0)):
            target = torch.zeros(1)
            target[0] = label[step]
            target = torch.tensor(target,dtype=torch.long)
            optimizer.zero_grad()
            output  = model(x)
            loss = criterion(output, target)
            #loss.backward()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            if step % 2 == 0:
                result = torch.max(output,1)[1].view(target.size())
                corrects = (result.data == target.data).sum()
                accuracy = corrects*100.0/1 ####1 is batch size 
                print('Epoch:',epoch,'step:',step,'- loss: %.6f'% loss.data.item(),\
                      'acc: %.4f'%accuracy)
    return model
        
 
if __name__ == '__main__':
    test = sentence2id(sentence)
    test.sen2sen(sentence)
    word,words_num = test.countword()
    test.word2id()
    
    sen_train = test.sen2id(sentence)
    sen_test = test.sen2id(test_sentence)
   
   
    X_train = torch.LongTensor((padded(sen_train,0)))
    X_test = torch.LongTensor((padded(sen_test,0)))
 
    Embedding = ModelEmbeddings(words_num+1,10,0)
    
    X_train_embed = Embedding.Embedding(X_train)
    X_test_embed = Embedding.Embedding(X_test)
    print(X_train_embed.size())
    #print(X_test_embed.size())
    
    ## TextCNN
    textcnn = textCNN(words_num,10,2)
    model = train(textcnn,X_train_embed,Label)  
    print(torch.max(model(X_test_embed),1)[1])

其中建立卷积层时,可以采用nn.ModuleList(),因为用起来不熟练,就直接展开了。等学好了,再补。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/154030.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月28日 下午12:00
下一篇 2022年6月28日 下午12:00


相关推荐

  • sqrt函数用法c语言 linux,C语言中sqrt函数如何使用

    sqrt函数用法c语言 linux,C语言中sqrt函数如何使用C语言中sqrt函数如何使用发布时间:2020-04-3010:08:20来源:亿速云阅读:370作者:小新C语言中sqrt函数如何使用?相信有很多人都不太了解,今天小编为了让大家更加了解sqrt函数,所以给大家总结了以下内容,一起往下看吧。c语言sqrt函数的用法sqrt函数用于计算一个非负实数的平方根。sqrt的函数原型:在VC6.0中的math.h头文件的函数原型为doublesqrt…

    2022年5月1日
    211
  • 常量表达式是什么_const常量

    常量表达式是什么_const常量常量表达式值(constant-expressionvalue)。通常情况下,常量表达式值必须被一个常量表达式赋值,而跟常量表达式函数一样,常量表达式值在使用前必须被初始化。一、常量表达式1.1运行时常量性与编译时常量性在C++中,我们常常会遇到常量的概念。常量表示该值不可修改,通常是通过const关键字来修饰的。比如:constinti=3;const还可以修饰函数参数、函数返回值、函数本身、类等。在不同的使用条件下,const有不同的意义,不过大多数情况下,const描述的都

    2026年4月15日
    4
  • 从零配置 OpenClaw 飞书机器人:我的踩坑与成功之旅

    从零配置 OpenClaw 飞书机器人:我的踩坑与成功之旅

    2026年3月13日
    3
  • 学生选课管理系统_学生管理系统的主要内容

    学生选课管理系统_学生管理系统的主要内容文件下载地址:https://download.csdn.net/download/axiebuzhen/108950621.业务描述设计本系统,模拟学生选课的部分管理功能。学生入校注册后需统一记录学生个人基本信息,对于面向学生开设的相关课程需要记录每门课程的基本信息,每个任课教师规定其可主讲三门课程,学生选课时系统将相应的选课信息记录入库,考试结束后需在相应的选课记录中补上考试成绩。简化…

    2022年10月15日
    7
  • RabbitMQ与CMQ的使用与实战

    RabbitMQ与CMQ的使用与实战RabbitMQ Rabbitmq 的启动和关闭 rabbitmq server 前台启动服务 rabbitmq server detached 后台启动服务 常用 rabbitmqctls 停止服务端口号是 5672 可视化端口 15672 Linux 中查看正在运行的端口号 netstat tulpn 终止与启动应用 Rabbitmqctls app 启动引用

    2026年3月19日
    1
  • java九九乘法表代码_java 输出九九乘法表口诀的代码「建议收藏」

    java九九乘法表代码_java 输出九九乘法表口诀的代码「建议收藏」题目:输出9*9口诀。程序分析:分行与列考虑,共9行9列,i控制行,j控制列。程序设计:publicclassjiujiu{publicstaticvoidmain(String[]args){inti=0;intj=0;for(i=1;i<=9;i++){for(j=1;j<=9;j++)System.out.print(i+”*”+j+”=”+i*j+”\t”)…

    2022年7月15日
    20

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号