PyCharm安装torch以及pytorch-pretrained-bert简单使用

PyCharm安装torch以及pytorch-pretrained-bert简单使用安装torch运行Pycharm中的代码时候提示ModuleNotFoundError:Nomodulenamed‘torch’。试了很多种方法都不行,然后进入官网查了下具体的安装方法,附上网址https://pytorch.org/get-started/previous-versions/。摘取一段放在这里供大家参考。#CUDA10.0pipinstalltorch===1.2.0torchvision===0.4.0-fhttps://download.pytorc

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

目录

安装torch

pytorch-pretrained-bert简单使用


安装torch

运行Pycharm中的代码时候提示ModuleNotFoundError: No module named ‘torch’。试了很多种方法都不行,然后进入官网查了下具体的安装方法,附上网址https://pytorch.org/get-started/previous-versions/。
摘取一段放在这里供大家参考。

# CUDA 10.0
pip install torch===1.2.0 torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html

# CUDA 9.2
pip install torch==1.2.0+cu92 torchvision==0.4.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html

# CPU only
pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

pytorch-pretrained-bert简单使用

从下载模型权重开始

# 切换到你的anaconda gpu 环境
# source activate 你的conda环境名称
​
# 安装加载预训练模型&权重的包
pip install pytorch-pretrained-bert

接着就是下载模型权重文件了,pytorch-pretrained-bert官方下载地址太慢了…,推荐去kaggle下载L-12_H-768-A-12 uncase版本,下载地址在这里,里面有两个文件,都下载下来,并把模型参数权重的文件bert-base-uncased解压出来,然后放在你熟悉的硬盘下即可。

加载模型试试

from pytorch_pretrained_bert import BertModel, BertTokenizer
import numpy as np
import torch

# 加载bert的分词器
tokenizer = BertTokenizer.from_pretrained('E:/Projects/bert-pytorch/bert-base-uncased-vocab.txt')
# 加载bert模型,这个路径文件夹下有bert_config.json配置文件和model.bin模型权重文件
bert = BertModel.from_pretrained('E:/Projects/bert-pytorch/bert-base-uncased/')

s = "I'm not sure, this can work, lol -.-"

tokens = tokenizer.tokenize(s)
print("\\".join(tokens))
# "i\\'\\m\\not\\sure\\,\\this\\can\\work\\,\\lo\\##l\\-\\.\\-"
# 是否需要这样做?
# tokens = ["[CLS]"] + tokens + ["[SEP]"]

ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
print(ids.shape)
# torch.Size([1, 15])

result = bert(ids, output_all_encoded_layers=True)
print(result)

没问题,那么bert返回给我们了什么呢?

result = (
    [encoder_0_output, encoder_1_output, ..., encoder_11_output], 
    pool_output
)
  1. 因为我选择了参数output_all_encoded_layers=True,12层Transformer的结果全返回了,存在第一个列表中,每个encoder_output的大小为[batch_size, sequence_length, hidden_size];
  2. pool_out大小为[batch_size, hidden_size],pooler层的输出在论文中描述为:
    which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (CLS) to train on the Next-Sentence task (see BERT’s paper).
    也就是说,取了最后一层Transformer的输出结果的第一个单词[cls]的hidden states,其已经蕴含了整个input句子的信息了。
  3. 如果你用不上所有encoder层的输出,output_all_encoded_layers参数设置为Fasle,那么result中的第一个元素就不是列表了,只是encoder_11_output,大小为[batch_size, sequence_length, hidden_size]的张量,可以看作bert对于这句话的表示。

用bert微调我们的模型

将bert嵌入我们的模型即可。

class CustomModel(nn.Module):
    
    def __init__(self, bert_path, n_other_features, n_hidden):
        super().__init__()
        # 加载并冻结bert模型参数
        self.bert = BertModel.from_pretrained(bert_path)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.output = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(768 + n_other_features, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1)
        )
    def forward(self, seqs, features):
        _, pooled = self.bert(seqs, output_all_encoded_layers=False)
        concat = torch.cat([pooled, features], dim=1)
        logits = self.output(concat)
        return logits

测试:

s = "I'm not sure, this can work, lol -.-"
​
tokens = tokenizer.tokenize(s)
ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
# print(ids)
# tensor([[1045, 1005, 1049, 2025, 2469, 1010, 2023, 2064, 2147, 1010, 8840, 2140,
#         1011, 1012, 1011]])
​
model = CustomModel('你的路径/bert-base-uncased/',10, 512)
outputs = model(ids, torch.rand(1, 10))
# print(outputs)
# tensor([[0.1127]], grad_fn=<AddmmBackward>)

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/175034.html原文链接:https://javaforall.net

(0)
上一篇 2022年8月28日 下午4:16
下一篇 2022年8月28日 下午4:16


相关推荐

  • TCP三次握手详解及释放连接过程[通俗易懂]

    TCP三次握手详解及释放连接过程[通俗易懂]TCP在传输之前会进行三次沟通,一般称为“三次握手”,传完数据断开的时候要进行四次沟通,一般称为“四次挥手”。两个序号和三个标志位:  (1)序号:seq序号,占32位,用来标识从TCP源端向目的端发送的字节流,发起方发送数据时对此进行标记。  (2)确认序号:ack序号,占32位,只有ACK标志位为1时,确认序号字段才有效,ack=seq+1。  (3)标志位:共6个,即URG、AC…

    2022年6月13日
    37
  • 中缀表达式转换为后缀表达式(C语言代码+详解)

    中缀表达式转换为后缀表达式(C语言代码+详解)中缀表达式转换为后缀表达式1.创建栈2.从左向右顺序获取中缀表达式a.数字直接输出b.运算符情况一:遇到左括号直接入栈,遇到右括号将栈中左括号之后入栈的运算符全部弹栈输出,同时左括号出栈但是不输出。情况二:遇到乘号和除号直接入栈,直到遇到优先级比它更低的运算符,依次弹栈。情况三:遇到加号和减号,如果此时栈空,则直接入栈,否则,将栈中优先级高的运算符依次弹栈(注意:加号和减号属于同一个…

    2022年6月16日
    28
  • 计算机dll修复工具,DLL修复工具哪个好?五款修复能力强推荐

    计算机dll修复工具,DLL修复工具哪个好?五款修复能力强推荐为什么会用到dll修复工具呢?因为我们在打开某些程序或者软件的时候会提示找不到某某.dll文件,关键是这些dll文件还不一样,我们去网上下载这些dll文件结果显示跟系统的版本不一致,反正就是各种麻烦,自己去找又费时又费力,而且往往对于有些游戏来说,修补了某一个dll又提示缺少另一个dll文件,这些其实可能都是系统本身太精简或者没有安装一些依赖软件导致的,这时候你完全不需要手动去找这些dll文件,只…

    2022年5月30日
    70
  • MessageDigest详解

    MessageDigest详解一、概述java.security.MessageDigest类用于为应用程序提供信息摘要算法的功能,如MD5或SHA算法。简单点说就是用于生成散列码。信息摘要是安全的单向哈希函数,它接收任意大小的数据,输出固定长度的哈希值。关于信息摘要和散列码请参照《数字证书简介》MessageDigest 通过其getInstance系列静态函数来进行实例化和初始化。MessageDigest对象通…

    2022年6月29日
    28
  • C语言程序设计第五版 谭浩强 第五版课后答案

    C语言程序设计第五版 谭浩强 第五版课后答案谭浩强C语言程序设计第五版第4章课后答案3.求两个正整数m和n,求其最大公约数和最小公倍数。#include<stdio.h>voidmain(){ intm,n,t,i,a=1; scanf(“%d%d”,&m,&n); if(m<n) { t=m; m=n; n=t; } for(i…

    2022年6月14日
    45
  • [股票预测]股票历史数据获取[通俗易懂]

    [股票预测]股票历史数据获取[通俗易懂]一、编程环境准备第一步:安装Anaconda3;第二步:安装工具包Pandas、tusharepipinstallPandaspipinstalltushare第三步:查看Pandas、tushare版本piplistpandas1.2.4tushare1.2.64二、股票历史行情数据提取2.1获取近3年个股日线交易数据通过参数设置获取日k线、周k线、月k线,…

    2022年6月24日
    40

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号