CBOW 更新[通俗易懂]

CBOW 更新[通俗易懂]代码:importtorchimporttorch.nnasnnimportnumpyasnpdefmake_context_vector(context,word_to_ix):idxs=[word_to_ix[w]forwincontext]returntorch.tensor(idxs,dtype=torch.long)…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

代码:

import torch
import torch.nn as nn
import numpy as np


def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    return torch.tensor(idxs, dtype=torch.long)


def get_index_of_max(input):
    index = 0
    for i in range(1, len(input)):
        if input[i] > input[index]:
            index = i
    return index


def get_max_prob_result(input, ix_to_word):
    return ix_to_word[get_index_of_max(input)]


CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
EMDEDDING_DIM = 100

word_to_ix = {}
ix_to_word = {}

raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# By deriving a set from `raw_text`, we deduplicate the array
vocab = set(raw_text)
vocab_size = len(vocab)

for i, word in enumerate(vocab):
    word_to_ix[word] = i
    ix_to_word[i] = word

data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))


class CBOW(torch.nn.Module):

    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()

        # out: 1 x emdedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.linear1 = nn.Linear(embedding_dim, 128)

        self.activation_function1 = nn.ReLU()

        # out: 1 x vocab_size
        self.linear2 = nn.Linear(128, vocab_size)

        self.activation_function2 = nn.LogSoftmax(dim=-1)

    def forward(self, inputs):
        embeds = sum(self.embeddings(inputs)).view(1, -1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

    def get_word_emdedding(self, word):
        word = torch.LongTensor([word_to_ix[word]])
        return self.embeddings(word).view(1, -1)


model = CBOW(vocab_size, EMDEDDING_DIM)

loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for epoch in range(50):
    total_loss = 0
    for context, target in data:
        context_vector = make_context_vector(context, word_to_ix)
        model.zero_grad()
        log_probs = model(context_vector)
        loss = loss_function(log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))
        loss.backward()
        optimizer.step()

        total_loss += loss.data

# ====================== TEST
context = ['People', 'create', 'to', 'direct']
context_vector = make_context_vector(context, word_to_ix)
a = model(context_vector).data.numpy()
print('Raw text: {}\n'.format(' '.join(raw_text)))
print('Context: {}\n'.format(context))
print('Prediction: {}'.format(get_max_prob_result(a[0], ix_to_word)))

结果:

Context: ['People', 'create', 'to', 'direct']

Prediction: programs

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/197603.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • log4j的配置ConversionPattern详细讲解[通俗易懂]

    log4j的配置ConversionPattern详细讲解[通俗易懂]原文来自https://blog.csdn.net/reserved_person/article/details/52849505感谢大佬先写下我一直没找到的ConversionPattern里面参数代表的详细含义参数 说明 例子 %c 列出logger名字空间的全称,如果加上{<层数>}表示列出从最内层算起的指定层数的名字空间 log4j配置文件…

    2022年8月22日
    8
  • LoadRunner 压力测试

    LoadRunner 压力测试一、LoadRunner安装1.复制一下地址,然后打开迅雷,新建,选择一个磁盘大的空间,显示4.02G的ISO文件http://www.genilogix.com/downloads/loadrunner/loadrunner-11.isohttp://h30302.www3.hp.com/prdownloads/Software_HP_LoadRunner_11.00_Sim_Chines

    2022年7月18日
    14
  • 写了很久,这是一份最适合/贴切普通大众/科班/非科班的『学习路线』

    写了很久,这是一份最适合/贴切普通大众/科班/非科班的『学习路线』说实话,对于学习路线这种文章我一般是不写的,大家看我的文章也知道,我是很少写建议别人怎么样怎么样的文章,更多的是,写自己的真实经历,然后供大家去参考,这样子,我内心也比较踏实,也不怕误导他人。但是,最近好多人问我学习路线,而且很多大一大二的,说自己很迷茫,看到我那篇普普通通,我的三年大学之后很受激励,觉得自己也能行,(是的,别太浪,你一定能行)希望我能给他个学习路线,说…

    2022年7月16日
    20
  • 程序员面试宝典——第6章

    程序员面试宝典——第6章1 宏定义 define 基本知识 defineSECOND PER YEAR 60 60 24 365 UL 宏定义只是定义 不牵扯计算 defineMIN A B A lt B A B 2 constint nbsp b 500 constint a amp b const 修饰指针所指向的变量 指针的内容为常量 intconst a amp b const 修

    2025年8月18日
    3
  • Android Sdk版本、Support包版本及常用框架最新版本汇总

    Android Sdk版本、Support包版本及常用框架最新版本汇总1.SDKVerion数据来源于维基百科,和一篇博客Api版本号代号发布时间主要更新内容11.0无2008-09-23Web浏览器显示,短信,媒体播放器,相机,Wifi及蓝牙支持21.1PetitFour(花式小蛋糕)2009-02-09邮件中保存附件31….

    2022年5月29日
    52
  • 在IDEA中实战Git「建议收藏」

    在IDEA中实战Git「建议收藏」工作中多人使用版本控制软件协作开发,常见的应用场景归纳如下:假设小组中有两个人,组长小张,组员小袁场景一:小张创建项目并提交到远程Git仓库场景二:小袁从远程Git仓库上获取项目源码场景三:小袁修改了部分源码,提交到远程仓库场景四:小张从远程仓库获取小袁的提交场景五:小袁接受了一个新功能的任务,创建了一个分支并在分支上开发场景六:小袁把分支提交到远程Git仓库场景七…

    2022年6月29日
    38

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号