博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
nnet3的代码分析
阅读量:5339 次
发布时间:2019-06-15

本文共 3761 字,大约阅读时间需要 12 分钟。

nnet3/nnet-common.h

定义了Index(n, t, x)三元组,表示第nbatch中第t帧。

并声明了关于IndexCindex的一些读写操作。

   

nnet3/nnet-nnet.h

声明了NetworkNode(主要包含其类型以及索引信息)

声明了Nnetnnet3网络类)

private:

//网络中的组件名列表

std::vector<std::string> component_names_;

//网络中实际的组件指针列表,同一组件可能出现多次

std::vector<Component*> components_;

//网络中结点名列表,即:inputscomponents以及outputs

//同一组件名会出现两次:foo-inputfoo

//因为foo-input有其自己的NetworkNode索引

std::vector<std::string> node_names_;

//网络中实际的结点指针列表

std::vector<NetworkNode> nodes_;

以及关于以上数据成员的实用函数。

   

nnet3/nnet-component-itf.h

Componentitfinterface,接口)

class Component

主要包含以下函数:

Propagate //正向传播

Backprop //反向传播

StoreStats //储存平均激活值、非线性函数微分平均值

ZeroStats //stats清零

GetInputIndexes //只适用于非简单组件

IsComputable //只适用于非简单组件

ReorderIndexes //只适用于非简单组件

以及实用函数

class RandomComponent: public Component

随机数生成的组件

class UpdatableComponent: public Component

参数扰动率

学习率

学习率因子

实际学习率(实际学习率=学习率*学习率因子)

冻结自然梯度更新

每个minibatch最大参数变换率(NnetTrainerL2正则化的形式使用)

标准L2正则化参数

的设定、修改、查询

class NonlinearComponent: public Component

由于该类不修改特征维数,因子该类是sigmoidsoftmaxReLU的基类

该类

储存激活平均值

储存训练中的微分

负责模型初始化

负责IO

nnet3/nnet-simple-component.h

class PnormComponent: public Component

p-norm的公式:

对维数为intput_dim的输入进行降维,输出维数为output_dim

PropagateBackprop函数十分简单,具体关于p-norm单元的实现位于

kaldi::CuMatrixBase::GroupPnorm

Kaldi::CuMatrixBase::DiffGroupPnorm

class DropoutComponent : public RandomComponent

DropoutComponent组件对输入以dropout比例随机置零,而梯度只在非零的输入处进行反向传播。通常只在训练期间使用此组件,但不在测试时间使用

Dropout: A Simple Way to Prevent Neural Networks from Overfitting"

   

Propogate()

//初始化一个元素取值范围为[0,1]的向量y

const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);

out->Add(-dropout);

out->ApplyHeaviside();

out->MulElements(in);

   

通过设置dropout_per_frame_,可以以帧的元素为单位dropout:

[[0,1,1,1],[1,0,1,1],[1,1,0,1],[1,1,1,0],[1,1,1,0]]

或帧为单位进行随机丢弃:

[[1,1,1,1],[0,0,0,0],[0,0,0,0],[1,1,1,1],[0,0,0,0]]

class ElementwiseProductComponent: public Component

点乘组件,用于降维

对于10维输入向量

(0.7,0.5,1.0,0.2,0.9,0.0,0.3,0.1,0.6,0.8)

假设输出维数为5,则10/5=2,两两相乘:

(0.7*0.5,1.0*0.2,0.9*0.0,0.3*0.1,0.6*0.8)

结果为

(0.35,0.20,0.0,0.03,0.48)

class SigmoidComponent: public NonlinearComponent

   

class TanhComponent: public NonlinearComponent

   

class RectifiedLinearComponent: public NonlinearComponent

   

class AffineComponent: public UpdatableComponent

   

class BlockAffineComponent : public UpdatableComponent

   

class RepeatedAffineComponent: public UpdatableComponent

   

class NaturalGradientRepeatedAffineComponent: public RepeatedAffineComponent

   

class SoftmaxComponent: public NonlinearComponent

Softmax损失函数(归一化指数函数):

其中o是输出向量

Backprop()

对于softmax函数的微分,令:

该函数的雅可比矩阵为:

令输出向量微分为e,输入向量微分为d,有:

   

nnet3/nnet-computation.h

负责实际的计算。

声明了ComputationRequestCommandTypeNnetComputation等类。

struct ComputationRequest

//计算需要的输入

std::vector<IoSpecification> inputs;

//计算预期的输出

std::vector<IoSpecification> outputs;

以及关于以上数据成员的实用函数

enum CommandType

神经网络计算类型,如:

kPropagate

kBackprop

kAllocMatrix

struct NnetComputation

编译后的神经网络具体计算特定步骤

给定NnetComputationRequest

就可编译得到该结构体

数据成员包括:

(子)矩阵信息及其索引(使用索引而不存储实际的矩阵)

矩阵

计算类型(CommandType

计算所依赖的输入输出Index

nnet3/nnet-analyze.h

检测计算是否能有效进行。

主要的类:

class ComputationAnalysis

private:

const NnetComputation &computation_;

const Analyzer &analyzer_;

ComputationVariables variables;

std::vector<CommandAttributes> command_attributes;

std::vector<std::vector<Access> > variable_accesses;

std::vector<MatrixAccesses> matrix_accesses;

成员函数:

访问索引s的第一个非初始化指令

访问索引s的第一个指令

访问索引s的最后一个指令

访问索引s的最后一个写指令

访问索引s的无效指令

访问矩阵索引m的第一个非初始化指令

访问矩阵索引m的最后一个指令

class ComputationChecker

ComputationAnalysis类似

主要检测:

维数一致性检测

未定义变量读取检测

读写冲突检测(是否是写完再读)

矩阵访问有效性检测

矩阵压缩检测

nnet3/nnet-example.h

struct NnetIo

std::vector<Index> indexes;

GeneralMatrix features;

特征(以及后验)的读写

struct NnetExample

//minibatch结构体

std::vector<NnetIo> io;

及其实用函数

以及一些关于NnetExample的比较、哈希等函数

   

   

转载于:https://www.cnblogs.com/JarvanWang/p/9152625.html

你可能感兴趣的文章
android 中对apache httpclient及httpurlconnection的选择
查看>>
1057. Stack (30) - 树状数组
查看>>
Charles使用三:设置代理
查看>>
linux下安装apache2.2.27
查看>>
C# 模式窗口下更新进度条
查看>>
git如何列出分支之间的差异commit
查看>>
tomcat配置证书
查看>>
手机端页面可以左右轻微拖动的bug
查看>>
关于Flume以及Kafka理解
查看>>
数据的前后台调用
查看>>
关于CS50课程
查看>>
Java基础之面向对象,类的创建和对象的实例化。
查看>>
swift泛型的5个要点和代码
查看>>
js中的typeof
查看>>
PHP:Deprecated: Function set_magic_quotes_runtime() is deprecated 错误
查看>>
来自网易云的黑科技,带尖角的div......
查看>>
十条不错的编成观点
查看>>
SQL Server操作实例
查看>>
spring-boot-单元测试参数数
查看>>
Python 中的线程与进程(四)
查看>>