链式求导法则公式_链式法则求导基础题

链式求导法则公式_链式法则求导基础题原题链接“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 ( 的计算图。现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入,,上述计算图获得函数值 (;并且根据微分链式法则,上图得到的梯度 ∇。知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

原题链接

“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 ( 的计算图。

在这里插入图片描述

现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入,,上述计算图获得函数值 (;并且根据微分链式法则,上图得到的梯度 ∇。

知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(e​x​​,即编程语言中的 exp(x) 函数)、对数(ln,即编程语言中的 log(x) 函数)和正弦函数(sin,即编程语言中的 sin(x) 函数)。

友情提醒:

常数的导数是 0;x 的导数是 1;e​x​​ 的导数还是 e​x​​;ln 的导数是 1;sin 的导数是 cos。
回顾一下什么是偏导数:在数学中,一个多变量的函数的偏导数,就是它关于其中一个变量的导数而保持其他变量恒定。在上面的例子中,当我们对 x​1​​ 求偏导数 / 时,就将 x​2​​ 当成常数,所以得到 ln 的导数是 1,x​1​​x​2​​ 的导数是 x​2​​,sin 的导数是 0。
回顾一下链式法则:复合函数的导数是构成复合这有限个函数在相应点的导数的乘积,即若有 (,(,则 /。例如对 sin 求导,就得到 cos。
如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。

输入格式:
输入在第一行给出正整数 N(≤),为计算图中的顶点数。

以下 N 行,第 i 行给出第 i 个顶点的信息,其中 ,。第一个值是顶点的类型编号,分别为:

0 代表输入变量
1 代表加法,对应 x​1​​+x​2​​
2 代表减法,对应 x​1​​−x​2​​
3 代表乘法,对应 x​1​​×x​2​​
4 代表指数,对应 e​x​​
5 代表对数,对应 ln
6 代表正弦函数,对应 sin
对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。

题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。

输出格式:
首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。

输入样例:

7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4
输出样例:
11.652
5.500 1.716

题解
将每个节点的输入节点编号存入到节点结构体中,然后先正向bfs求每个节点的输出值,然后再方向求每个节点的导数,注意求导数的时候每个节点存储的导数是其输出变量的导数,求解的时候应该按照不同路径的导数相加,同一路径上的导数相乘。
无论是正向还是方向均应按照拓扑序求解

#include<bits/stdc++.h>
#include<cmath>
#define x first
#define y second
#define send string::npos
#define lowbit(x) (x&(-x))
#define left(x) x<<1
#define right(x) x<<1|1
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
typedef struct Node * pnode;
const int N = 1e6 + 10;
const int M = 3 * N;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const int Mod = 1e9;
int out[N],in[N];
struct Node{ 
   
    double v,f;    //v代表此节点输出值,f代表输出值导数
    int la,lb;
    int op;
}node[N];
int head[N],cnt;
int q[N],tt = 0,hh = 0;
vector<int>s;   //起始节点
struct Edge{ 
   
    int v,next;
}edge[2 * M];
void add(int u,int v){ 
   
    edge[cnt].v = v;
    edge[cnt].next = head[u];
    head[u] = cnt ++;
}
double op123(int t,double a,double b){ 
   
    if(t == 1)return a + b;
    if(t == 2)return a - b;
    if(t == 3)return a * b;
}
double op456(int t,double a){ 
   
    if(t == 4)return exp(a);
    if(t == 5)return log(a);
    if(t == 6)return sin(a);
}
void bfs(){ 
   
    for(int i = 0;i < s.size();i ++)q[tt ++] = s[i];
    while(hh < tt){ 
   
        int t = q[hh ++];
// cout<<t<<endl;
        double a = node[node[t].la].v,b = node[node[t].lb].v;
        int type = node[t].op;
        if(type == 1 || type == 2 || type == 3)node[t].v = op123(type,a,b);
        else if(type != 0)node[t].v = op456(type,a);
// cout<<t<<" "<<type<<" "<<node[t].v<<endl;
        for(int i = head[t];~i;i = edge[i].next){ 
   
            int v = edge[i].v;
            in[v] --;
            if(in[v] == 0)	//如果此处不按拓扑序,则会产生大量的重复节点
                q[tt ++] = v;
        }
    }
}
void top(int root){ 
   
    hh = tt = 0;
    q[tt ++] = root;
    while(hh < tt){ 
   
        int t = q[hh ++];
// cout<<t<<endl;
        if(node[t].op == 1 || node[t].op == 2 || node[t].op == 3){ 
   
            if(node[t].op == 1){ 
   
                node[node[t].la].f += (1 * node[t].f);
                node[node[t].lb].f += (1 * node[t].f);
            }else if(node[t].op == 2){ 
   
                node[node[t].la].f += (1 * node[t].f);
                node[node[t].lb].f += (-1 * node[t].f);
            }else{ 
   
                node[node[t].la].f += (node[node[t].lb].v * node[t].f);
                node[node[t].lb].f += (node[node[t].la].v * node[t].f);
            }
            out[node[t].la] --,out[node[t].lb] --;
            if(out[node[t].la] == 0)q[tt ++] = node[t].la;
            if(out[node[t].lb] == 0)q[tt ++] = node[t].lb;
        }
        else if(node[t].op != 0){ 
   
            if(node[t].op == 4)node[node[t].la].f += (node[t].v * node[t].f);
            else if(node[t].op == 5)node[node[t].la].f += (node[t].f / node[node[t].la].v) ;
            else node[node[t].la].f += (cos(node[node[t].la].v) * node[t].f);
            out[node[t].la] --;
            if(out[node[t].la] == 0)q[tt ++] = node[t].la;
        }
    }
}
int main(){ 
   
    memset(head,-1,sizeof head);
    int n,t,a,b;
    double x;
    cin>>n;
    for(int i = 0;i < n;i ++){ 
   
        cin>>t;
        if(t == 0){ 
   
            cin>>x;
            s.push_back(i);
            node[i].v = x;
        }
        else if(t == 1 || t == 2 || t == 3){ 
   
            cin>>a>>b;
            add(a,i);
            add(b,i);
            out[a] ++,out[b] ++;
            in[i] ++,in[i] ++;
            node[i].la = a,node[i].lb = b;
        }else if(t == 4 || t == 5 || t == 6){ 
   
            cin>>a;
            add(a,i);
            out[a] ++;
            in[i] ++;
            node[i].la = a;
        }
        node[i].op = t;
    }
    int root = -1;
    for(int i = 0;i < n;i ++)
        if(!out[i])
            root = i;
// cout<<"root:"<<root<<endl;
    bfs();
    node[root].f = 1;		//最后一个节点的输出值导数应该为1
    top(root);
    printf("%.3f\n%.3f",node[root].v,node[s[0]].f);
    for(int i = 1;i < s.size();i ++)printf(" %.3f",node[s[i]].f);
    return 0;
}

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/168768.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 手机开发平台介绍[通俗易懂]

    手机开发平台介绍[通俗易懂]
    刚刚进入手机研发领域,为今后职业生涯规划,从网上搜了资料方便学习。
    手机客户端软件开发最大的困难就是平台不统一,手机开发平台太多。
    手机可分为智能手机开发和featherphone手机。开发平台可分为开放式平台和封闭式平台,开放式平台包括symbian、windowsmobile、linux、iPhone、Android、BlackBerry、j2me、brew等,支持手机应用程序通过OTA下载和安装;封闭式平台包括MTK、展讯、TI、飞利浦等。下面分别介绍。

    2022年8月12日
    6
  • 融合计费账务系统架构与核心功能的研究与实现

    融合计费账务系统架构与核心功能的研究与实现2006年初,融合计费账务系统的发展趋势及其重要性已得到业界的广泛关注,各电信运营商及开发商也开始了相应的讨论、研究和规划,北京联通(原北京网通)在业务和网络的发展驱动下,率先开始了融合计费账务系统的规划与建设,真正建设一个统一支撑大客户、商务客户和公众客户所有客户群,统一支撑北京联通电话、宽带、小灵通、互联网、专线及CP/SP业务等全业务及其灵活捆绑与组合营销,统一支撑在线…

    2025年6月17日
    0
  • snmptrap怎么发送_cmd运行nmap

    snmptrap怎么发送_cmd运行nmapSNMP简单网络管理协议,其中其支持的一个命令snmptrap命令,用于模拟向管理机发送trap消息。启动陷阱方法:snmptrapd-C-c/etc/snmp/snmptrapd.conf-Lf/var/log/net-snmptrap.log例如:snmptrap-v1-cpublic192.168.2.124.1.3.6.1.4.1.1192.168.2.12561…

    2022年8月20日
    21
  • 使用srvany.exe把程序安装成windows服务的方法

    使用srvany.exe把程序安装成windows服务的方法2019独角兽企业重金招聘Python工程师标准>>>…

    2022年5月30日
    47
  • 论计算机发展史及展望_策略单元培训心得

    论计算机发展史及展望_策略单元培训心得一种对计算机发展史展开研究的策略(3页)本资源提供全文预览,点击全文预览即可全文预览,如果喜欢文档就下载吧,查找使用更方便哦!9.9积分一种对计算机发展史展开研究的策略一种对计算机发展史展开研究的策略一种对计算机发展史展开研究的策略一、引言随着中国的开放,科学技术的国际交流日益深入,现代化意义上的计算机产品与技术被不断介绍并引入到国内,且在短时间内取得了迅.L.猛的发展。然而,作为…

    2022年10月18日
    0
  • 纸的大小图解_常用纸张尺寸及示意图(A0,A1A3,A4,A5A8)数据源维基百科.PDF

    纸的大小图解_常用纸张尺寸及示意图(A0,A1A3,A4,A5A8)数据源维基百科.PDF常用纸张尺寸及示意图(A0,A1A3,A4,A5A8)数据源维基百科mFPCharging-常用纸张尺寸说明2011-12-09v2.0常用纸张尺寸及示意图(A0,A1…A3,A4,A5…A8)数据源:维基百科标准定义ISO216定义了A、B、C三个系列的纸张尺寸。C系列纸张尺寸主要使用于信封。ISO216的格式遵循着的比率;放在一起的两张纸有着…

    2022年6月20日
    64

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号