DyNet documentation
DyNet (formerly known as cnn) is a neural network library developed by Carnegie Mellon University and many others. It is written in C++ (with bindings in Python) and is designed to be efficient when run on either CPU or GPU, and to work well with networks that have dynamic structures that change for every training instance. Networks of this kind are particularly important in natural language processing tasks, and DyNet has been used to build state-of-the-art systems for syntactic parsing, machine translation, morphological inflection, and many other application areas.
Read the documentation below to get started, and feel free to contact the dynet-users group with any questions (if you want to receive email, make sure to select “all email” when you sign up). We greatly appreciate any bug reports and contributions, which can be made by filing an issue or making a pull request through the GitHub page.
You can also read more technical details in our technical report. If you use DyNet for research, please cite this report as follows:
@article{dynet,
title={DyNet: The Dynamic Neural Network Toolkit},
author={Graham Neubig and Chris Dyer and Yoav Goldberg and Austin Matthews and Waleed Ammar and Antonios Anastasopoulos and Miguel Ballesteros and David Chiang and Daniel Clothiaux and Trevor Cohn and Kevin Duh and Manaal Faruqui and Cynthia Gan and Dan Garrette and Yangfeng Ji and Lingpeng Kong and Adhiguna Kuncoro and Gaurav Kumar and Chaitanya Malaviya and Paul Michel and Yusuke Oda and Matthew Richardson and Naomi Saphra and Swabha Swayamdipta and Pengcheng Yin},
journal={arXiv preprint arXiv:1701.03980},
year={2017}
}
DyNet can be installed according to the instructions below:
Building/Installing
How to build DyNet and link it with your programs
Prerequisites
DyNet relies on a number of external libraries and tools, including Boost, CMake, Eigen, and Mercurial (used to download Eigen). Boost, CMake, and Mercurial can be installed from standard repositories.
For example on Ubuntu Linux:
sudo apt-get install libboost-all-dev cmake mercurial
Or on macOS, first make sure the Apple Command Line Tools are installed, then get Boost, CMake, and Mercurial with either Homebrew or MacPorts:
xcode-select --install
brew install boost cmake hg # Using homebrew.
sudo port install boost cmake mercurial # Using macports.
To compile DyNet you also need the development version of the Eigen library. If you use any of the released versions, you may get assertion failures or compile errors. If you don’t have Eigen installed already, you can get it easily using the following command:
hg clone https://bitbucket.org/eigen/eigen/ -r 346ecdb
cd eigen
mkdir build && cd build
cmake ..
make install # sudo permissions might be necessary on Linux.
cd ../..
The -r NUM option specifies a revision number that is known to work. Adventurous users can remove it and use the very latest version, at the risk of the code breaking or not compiling. On macOS, you can install the latest development version of Eigen using Homebrew:
brew install --HEAD eigen
Building
To get and build DyNet, clone the repository
git clone https://github.com/clab/dynet.git
then enter the directory and use CMake (http://www.cmake.org/) to generate the makefiles:
cd dynet
mkdir build
cd build
cmake .. -DEIGEN3_INCLUDE_DIR=/path/to/eigen
Then compile, where “2” can be replaced by the number of cores on your machine
make -j 2
To see that things have built properly, you can run
./examples/train_xor
which will train a multilayer perceptron to predict the xor function.
Compiling/linking external programs
When you want to use DyNet in an external program, you will need to add the dynet directory to the compile path:
-I/path/to/dynet
and link with the DyNet library:
-L/path/to/dynet/build/dynet -ldynet
Debugging build problems
If you have a build problem and want to debug, please run
make clean
make VERBOSE=1 &> make.log
then examine the commands in the make.log file to see if anything looks fishy. If you would like help, send this make.log file via the “Issues” tab on GitHub, or to the dynet-users mailing list.
GPU/MKL support and build options
GPU (CUDA) support
DyNet supports running programs on GPUs with CUDA. If you have CUDA installed, you can build DyNet with GPU support by adding
-DBACKEND=cuda
to your cmake options. This will result in two libraries, named “libdynet” and “libgdynet”, being created. When you want to run a program on CPU, you can link to the “libdynet” library as shown above. When you want to run a program on GPU, you can link to the “libgdynet” library:
-L/path/to/dynet/build/dynet -lgdynet
(Eventually you will be able to use a single library to run on either CPU or GPU, but this is not fully implemented yet.)
MKL support
DyNet can leverage Intel’s MKL library to speed up computation on the CPU. As an example, we’ve seen a 3x speedup in seq2seq training when using MKL. To use MKL, include the following cmake option:
-DMKL=TRUE
If CMake is unable to find MKL automatically, try setting MKL_ROOT, such as
-DMKL_ROOT="/path/to/MKL"
One common install location is /opt/intel/mkl/.
If either MKL or MKL_ROOT is set, CMake will look for MKL.
By default, MKL will use all CPU cores. You can control how many cores MKL uses by setting the environment variable MKL_NUM_THREADS to the desired number. The following is the total time to process 250 training examples when running the encdec example (on a 6-core Intel Xeon E5-1650):
encdec.exe --dynet-seed 1 --dynet-mem 1000 train-hsm.txt dev-hsm.txt
+-----------------+------------+---------+
| MKL_NUM_THREADS | Cores Used | Time(s) |
+-----------------+------------+---------+
| <Without MKL> | 1 | 28.6 |
| 1 | 1 | 13.3 |
| 2 | 2 | 9.5 |
| 3 | 3 | 8.1 |
| 4 | 4 | 7.8 |
| 6 | 6 | 8.2 |
+-----------------+------------+---------+
As you can see, for this particular example, using MKL roughly doubles the speed of computation while still using only one core. Increasing the number of cores to 2 or 3 is quite beneficial, but beyond that there are diminishing returns or even slowdown.
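The speedups described in the paragraph above can be checked with a few lines of arithmetic on the table's numbers. This sketch only recomputes ratios from the published timings. It is not a new benchmark.

```python
# Seconds to process 250 examples, copied from the table above.
# The key None stands for the build without MKL.
timings = {None: 28.6, 1: 13.3, 2: 9.5, 3: 8.1, 4: 7.8, 6: 8.2}

baseline = timings[None]
speedups = {n: baseline / t for n, t in timings.items() if n is not None}

for n in sorted(speedups):
    print("MKL_NUM_THREADS=%d: %.2fx faster than without MKL" % (n, speedups[n]))

best = max(speedups, key=speedups.get)
print("fastest setting: %d threads" % best)
```

With one thread, MKL comes out about 2.15x faster. The best setting (4 threads) is about 3.67x faster, and going to 6 threads is slightly slower again.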
Non-standard Boost location
DyNet requires Boost, and will find it if it is in the standard location. If Boost is in a non-standard location, say $HOME/boost, you can specify the location by adding the following to your CMake options:
-DBOOST_ROOT:PATHNAME=$HOME/boost -DBoost_LIBRARY_DIRS:FILEPATH=$HOME/boost/lib
-DBoost_NO_BOOST_CMAKE=TRUE -DBoost_NO_SYSTEM_PATHS=TRUE
Note that you will also have to set your LD_LIBRARY_PATH to point to the boost/lib directory.
Note also that Boost must be compiled with the same compiler version as you are using to compile DyNet.
Building for Windows
DyNet has been tested to build on Windows using Microsoft Visual Studio 2015. You may be able to build with MSVC 2013 by slightly modifying the instructions below.
First, install Eigen following the above instructions.
Second, install Boost for your compiler and platform. Follow the instructions for compiling Boost or just download the already-compiled binaries.
To generate the MSVC solution and project files, run cmake, pointing it to the location you installed Eigen and Boost (for example, at c:\libs\Eigen and c:\libs\boost_1_61_0):
mkdir build
cd build
cmake .. -DEIGEN3_INCLUDE_DIR=c:\libs\Eigen -DBOOST_ROOT=c:\libs\boost_1_61_0 -DBOOST_LIBRARYDIR=c:\libs\boost_1_61_0\lib64-msvc-14.0 -DBoost_NO_BOOST_CMAKE=ON -G"Visual Studio 14 2015 Win64"
This will generate dynet.sln and a bunch of *.vcxproj files (one for the DyNet library, and one per example). You should be able to just open dynet.sln and build all. Note: multi-process functionality is currently not supported on Windows, so the multi-process examples (`*-mp`) will not be included in the generated solution.
The Windows build also supports CUDA with the latest version of Eigen (as of Oct 28, 2016), with the following code change:
- TensorDeviceCuda.h: Change sleep(1) to Sleep(1000)
Installing the Python DyNet module
(for instructions on installing on a computer with GPU, see below)
Python bindings to DyNet are supported for both Python 2.x and 3.x.
TL;DR
(see below for the details)
# Installing Python DyNet:
pip install cython # if you don't have it already.
mkdir dynet-base
cd dynet-base
# getting dynet and eigen
git clone https://github.com/clab/dynet.git
hg clone https://bitbucket.org/eigen/eigen -r 346ecdb # -r NUM specifies a known working revision
cd dynet
mkdir build
cd build
# without GPU support:
cmake .. -DEIGEN3_INCLUDE_DIR=../../eigen -DPYTHON=`which python`
# or with GPU support:
cmake .. -DEIGEN3_INCLUDE_DIR=../../eigen -DPYTHON=`which python` -DBACKEND=cuda
make -j 2 # replace 2 with the number of available cores
cd python
python setup.py install # or `python setup.py install --user` for a user-local install.
# this should suffice, but on some systems you may need to add the following line to your
# init files in order for the compiled .so files to be accessible to Python.
# /path/to/dynet/build/dynet is the location in which libdynet.dylib (libdynet.so on Linux) resides.
# DYLD_LIBRARY_PATH is the macOS variable; on Linux, set LD_LIBRARY_PATH to the same directory instead.
export DYLD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$DYLD_LIBRARY_PATH
Detailed Instructions
First, get DyNet:
cd $HOME
mkdir dynet-base
cd dynet-base
git clone https://github.com/clab/dynet.git
cd dynet
git submodule init # To be consistent with DyNet's installation instructions.
git submodule update # To be consistent with DyNet's installation instructions.
Then get Eigen:
cd $HOME
cd dynet-base
hg clone https://bitbucket.org/eigen/eigen/ -r 346ecdb
(-r NUM specifies a known working revision of Eigen. You can remove this in order to get the bleeding edge Eigen, with the risk of some compile breaks, and the possible benefit of added optimizations.)
We also need to make sure the cython module is installed (you can replace pip with your favorite package manager, such as conda, or install within a virtual environment):
pip install cython
To simplify the following steps, we can set bash variables to hold the locations of the main DyNet and Eigen directories. If you obtained DyNet and Eigen differently from the instructions above and saved them in different location(s), these variables will be helpful:
PATH_TO_DYNET=$HOME/dynet-base/dynet/
PATH_TO_EIGEN=$HOME/dynet-base/eigen/
Compile DyNet.
This is pretty much the same process as compiling the C++ library described above, with the addition of the -DPYTHON= flag, pointing to the location of your Python interpreter.
If Boost is installed in a non-standard location, you should add the corresponding flags to the cmake command line; see the DyNet installation instructions page.
cd $PATH_TO_DYNET
PATH_TO_PYTHON=`which python`
mkdir build
cd build
cmake .. -DEIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN -DPYTHON=$PATH_TO_PYTHON
make -j 2
Assuming that the cmake command found all the needed libraries and didn’t fail, the make command will take a while, and compile DyNet as well as the Python bindings. You can change make -j 2 to a higher number, depending on the available cores you want to use while compiling.
You now have a working Python binding inside of build/python. To verify this is working:
cd $PATH_TO_DYNET/build/python
python
then, within Python:
import dynet as dy
print(dy.__version__)
model = dy.Model()
In order to install the module so that it is accessible from everywhere in the system, run the following:
cd $PATH_TO_DYNET/build/python
python setup.py install --user
The --user switch will install the module in your local site-packages, and works without root privileges. To install the module to the system site-packages (for all users), or to the current virtualenv (if you are in one), run python setup.py install without this switch.
You should now have a working Python binding (the dynet module).
Note, however, that the installation relies on the compiled DyNet library being in $PATH_TO_DYNET/build/dynet, so make sure not to move it from there.
Now, check that everything works:
cd $PATH_TO_DYNET
cd examples/python
python xor.py
python rnnlm.py rnnlm.py
Alternatively, if the following script works for you, then your installation is likely to be working:
from dynet import *
model = Model()
If it doesn’t work and you get an error similar to the following:
ImportError: dlopen(/Users/sneharajana/.python-eggs/dyNET-0.0.0-py2.7-macosx-10.11-intel.egg-tmp/_dynet.so, 2): Library not loaded: @rpath/libdynet.dylib
Referenced from: /Users/sneharajana/.python-eggs/dyNET-0.0.0-py2.7-macosx-10.11-intel.egg-tmp/_dynet.so
Reason: image not found
then you may need to run the following (and add it to your shell init files):
export DYLD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$DYLD_LIBRARY_PATH
Usage
There are two ways to import the dynet module:
import dynet
imports dynet and automatically initializes the global dynet parameters with the command line arguments (see the documentation). From then on, the amount of allocated memory and the choice of GPU or CPU are fixed.
import _dynet
# or
import _gdynet # For GPU
Imports dynet for CPU (resp. GPU) and doesn’t initialize the global parameters. These must be initialized manually before using dynet, using one of the following:
# Same as import dynet as dy
import _dynet as dy
dy.init()
# Or configure the global parameters explicitly
import _dynet as dy
# Declare a DynetParams object
dyparams = dy.DynetParams()
# Fetch the command line arguments (optional)
dyparams.from_args()
# Set some parameters manually (see the command line arguments documentation)
dyparams.set_mem(2048)
dyparams.set_random_seed(666)
dyparams.set_weight_decay(1e-7)
dyparams.set_shared_parameters(False)
dyparams.set_requested_gpus(1)
dyparams.set_gpu_mask([0,1,1,0])
# Initialize with the given parameters
dyparams.init() # or init_from_params(dyparams)
Anaconda Support
Anaconda is a popular package management system for Python. DyNet can be used from within an Anaconda environment, but be sure to activate the environment:
source activate my_environment_name
then install some necessary packages as follows:
conda install gcc cmake boost cython
After this, the build process should be the same as normal.
Note that on some conda environments, people have reported build errors related to the interaction between the icu and boost packages. If you encounter this, try the solution in this comment.
Windows Support
You can also use Python on Windows by following similar steps to the above. For simplicity, we recommend using a Python distribution that already has Cython installed. The following has been tested to work:
- Install WinPython 2.7.10 (comes with Cython already installed).
- Run CMake as above with -DPYTHON=/path/to/your/python.exe.
- Open a command prompt and set VS90COMNTOOLS to the path to your Visual Studio “Common7/Tools” directory. One easy way to do this is a command such as:
set VS90COMNTOOLS=%VS140COMNTOOLS%
- Open dynet.sln from this command prompt and build the “Release” version of the solution.
- Follow the rest of the instructions above for testing the build and installing it for other users.
Note: currently only the Release version works.
GPU/MKL Support
Installing/running on GPU
For installing on a computer with GPU, first install CUDA. The following instructions assume CUDA is installed.
The installation process is pretty much the same, with the addition of the -DBACKEND=cuda flag at the cmake stage:
cmake .. -DEIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN -DPYTHON=$PATH_TO_PYTHON -DBACKEND=cuda
(If CUDA is installed in a non-standard location and cmake cannot find it, you can also specify -DCUDA_TOOLKIT_ROOT_DIR=/path/to/cuda.)
Now, build the Python modules (as above, we assume Cython is installed):
After running make -j 2, you should have the files _dynet.so and _gdynet.so in the build/python folder.
As before, cd build/python followed by python setup.py install --user will install the module.
In order to use the GPU support, you can either:
- Use import _gdynet as dy instead of import dynet as dy.
- Or (preferred), import dynet as usual, but use the command-line switch --dynet-gpu or the GPU switches detailed here when invoking the program. This option lets the same code work with either the GPU or the CPU version depending on how it is invoked.
Running with MKL
If you’ve built DyNet to use MKL (using -DMKL or -DMKL_ROOT), Python sometimes has difficulty finding the MKL shared libraries. You can try setting LD_LIBRARY_PATH to point to your MKL library directory.
If that doesn’t work, try setting the following environment variable (supposing, for example, that your MKL libraries are located at /opt/intel/mkl/lib/intel64):
export LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so
The following sections give the basic information needed to create programs and use models:
DyNet Tutorial
C++ Tutorial
See the tutorials for the C++ version of DyNet
Basic Tutorial
An illustration of how models are trained (for a simple logistic regression model) is below:
First, we set up the structure of the model.
Create a model, and an SGD trainer to update its parameters.
Model mod;
SimpleSGDTrainer sgd(mod);
Create a “computation graph,” which will define the flow of information.
ComputationGraph cg;
Initialize a 1x3 parameter vector, and add the parameters to be part of the computation graph.
Expression W = parameter(cg, mod.add_parameters({1, 3}));
Create variables defining the input and output of the regression, and load them into the computation graph. Note that we don’t need to set concrete values yet.
vector<dynet::real> x_values(3);
Expression x = input(cg, {3}, &x_values);
dynet::real y_value;
Expression y = input(cg, &y_value);
Next, set up the structure to multiply the input by the weight vector, then run the output of this through a logistic sigmoid function (logistic regression).
Expression y_pred = logistic(W*x);
Finally, we create a function to calculate the loss. The model will be optimized to minimize the value of the final function in the computation graph.
Expression l = binary_log_loss(y_pred, y);
We are now done setting up the graph, and we can print out its structure:
cg.print_graphviz();
Now, we perform a parameter update for a single example. Set the input/output to the values specified by the training data:
x_values = {0.5, 0.3, 0.7};
y_value = 1.0;
“forward” propagates values forward through the computation graph, and returns the loss.
dynet::real loss = as_scalar(cg.forward(l));
“backward” performs back-propagation, and accumulates the gradients of the parameters within the “Model” data structure.
cg.backward(l);
“sgd.update” updates parameters of the model that was passed to its constructor. Here 1.0 is the scaling factor that allows us to control the size of the update.
sgd.update(1.0);
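The forward/backward/update cycle above can be made concrete with a plain-Python sketch of the same single-example logistic-regression step. No DyNet is involved. The starting weights and the learning rate of 0.1 are hypothetical choices for illustration. The update is ordinary SGD using the standard gradient of a logistic output with log loss, and DyNet's trainer may differ in details.

```python
import math

def logistic(z):
    return 1.0 / (1.0 + math.exp(-z))

def binary_log_loss(p, y):
    # Binary cross-entropy: -(y*log(p) + (1-y)*log(1-p))
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def predict(W, x):
    # y_pred = logistic(W*x) for a 1x3 weight vector W and a 3-dim input x
    return logistic(sum(w_i * x_i for w_i, x_i in zip(W, x)))

W = [0.1, -0.2, 0.05]        # hypothetical initial weights (DyNet initializes them randomly)
x_values = [0.5, 0.3, 0.7]   # same training example as above
y_value = 1.0
learning_rate = 0.1          # hypothetical step size

# "forward": compute the loss for this example
y_pred = predict(W, x_values)
loss_before = binary_log_loss(y_pred, y_value)

# "backward": for logistic + log loss, d loss / d W_i = (y_pred - y) * x_i
grad_W = [(y_pred - y_value) * x_i for x_i in x_values]

# "update": one plain SGD step against the gradient
W = [w_i - learning_rate * g_i for w_i, g_i in zip(W, grad_W)]

loss_after = binary_log_loss(predict(W, x_values), y_value)
print("loss before: %.4f, after: %.4f" % (loss_before, loss_after))
```

Because the target is 1 and the prediction is below 1, every gradient entry is negative. The step therefore increases each weight in proportion to its input, which lowers the loss on this example.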
Note that this very simple example doesn’t cover things like memory initialization, reading/writing models, recurrent/LSTM networks, or adding biases to functions. The best way to get an idea of how to use DyNet for real is to look in the example directory, starting with the simplest xor example.
Python Tutorial
Guided examples in Python can be found below:
Working with the Python DyNet package
The DyNet package is intended for training and using neural networks, and is particularly suited for applications with dynamically changing network structures. It is a Python wrapper for the DyNet C++ package.
In neural network packages there are generally two modes of operation:
- Static networks, in which a network is built once and then fed with different inputs/outputs. Most NN packages work this way.
- Dynamic networks, in which a new network is built for each training example (sharing parameters with the networks of other training examples). This approach is what makes DyNet unique, and where most of its power comes from.
We will describe both of these modes.
Package Fundamentals
The main piece of DyNet is the ComputationGraph, which essentially defines a neural network. The ComputationGraph is composed of expressions, which relate to the inputs and outputs of the network, as well as the Parameters of the network. The parameters are the things in the network that are optimized over time, and all of the parameters sit inside a Model. There are trainers (for example SimpleSGDTrainer) that are in charge of setting the parameter values.
We will not be using the ComputationGraph directly, but it is there in the background, as a singleton object. When dynet is imported, a new ComputationGraph is created. We can then reset the computation graph to a new state by calling renew_cg().
Static Networks
The life-cycle of a DyNet program is:
1. Create a Model, and populate it with Parameters.
2. Renew the computation graph, and create Expressions representing the network (the network will include the Expressions for the Parameters defined in the model).
3. Optimize the model for the objective of the network.
As an example, consider a model for solving the “xor” problem. The network has two inputs, which can be 0 or 1, and a single output which should be the xor of the two inputs. We will model this as a multi-layer perceptron with a single hidden layer.
Let \(x = (x_1, x_2)\) be our input. We will have a hidden layer of 8 nodes, and an output layer of a single node. The activation on the hidden layer will be a \(\tanh\). Our network will then be:
\(\sigma(V(\tanh(Wx+b)))\)
where \(W\) is an \(8 \times 2\) matrix, \(V\) is a \(1 \times 8\) matrix, and \(b\) is an 8-dim vector.
We want the output to be either 0 or 1, so we take the output layer to be the logistic-sigmoid function, \(\sigma(x) = 1/(1+e^{-x})\), which takes values between \(-\infty\) and \(+\infty\) and returns numbers in \((0,1)\).
We will begin by defining the model and the computation graph.
In [1]:
# we assume that you have the dynet module in your path.
from dynet import *
In [2]:
# create a model and add the parameters.
m = Model()
pW = m.add_parameters((8,2))
pV = m.add_parameters((1,8))
pb = m.add_parameters((8))
renew_cg() # new computation graph. not strictly needed here, but good practice.
# associate the parameters with cg Expressions
W = parameter(pW)
V = parameter(pV)
b = parameter(pb)
In [3]:
#b[1:-1].value()
b.value()
Out[3]:
[-0.5920619964599609,
-0.4818088114261627,
-0.011437613517045975,
-0.7547096610069275,
0.2887613773345947,
-0.39806437492370605,
-0.8494511246681213,
0.295582115650177]
The first part of the cell above creates a model and populates it with parameters. The second part creates a computation graph and adds the parameters to it, transforming them into Expressions. The need to distinguish model parameters from “expressions” will become clearer later.
We now use the W, V, and b expressions to create the complete expression for the network.
In [4]:
x = vecInput(2) # an input vector of size 2. Also an expression.
output = logistic(V*(tanh((W*x)+b)))
In [5]:
# we can now query our network
x.set([0,0])
output.value()
Out[5]:
0.706532895565033
In [6]:
# we want to be able to define a loss, so we need an input expression to work against.
y = scalarInput(0) # this will hold the correct answer
loss = binary_log_loss(output, y)
In [7]:
x.set([1,0])
y.set(0)
print(loss.value())
y.set(1)
print(loss.value())
1.25551486015
0.335373580456
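The two printed values are consistent with binary_log_loss being the standard binary cross-entropy, \(-(y \log p + (1-y)\log(1-p))\). Both values come from the same network output \(p\), once with \(y=0\) and once with \(y=1\). The sketch below assumes that definition rather than quoting DyNet's source. It recovers \(p\) from the \(y=1\) loss and then reproduces the \(y=0\) loss.

```python
import math

def binary_log_loss(p, y):
    """Binary cross-entropy for a predicted probability p and a 0/1 target y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# The loss printed for y=1 is -log(p); invert it to recover the network output p.
loss_y1 = 0.335373580456
p = math.exp(-loss_y1)

print("output p = %.4f" % p)
print("loss(y=0) = %.4f" % binary_log_loss(p, 0))
print("loss(y=1) = %.4f" % binary_log_loss(p, 1))
```

The recovered output is about 0.715, and the \(y=0\) loss comes out at about 1.2555, matching the value printed above.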
Training
We now want to set the parameter weights such that the loss is minimized.
For this, we will use a trainer object. A trainer is constructed with respect to the parameters of a given model.
In [8]:
trainer = SimpleSGDTrainer(m)
To use the trainer, we need to:
- call the forward_scalar method of ComputationGraph. This will run a forward pass through the network, calculating all the intermediate values until the last one (loss, in our case), and then convert the value to a scalar. The final output of our network must be a single scalar value. However, if we do not care about the value, we can just use cg.forward() instead of cg.forward_scalar().
- call the backward method of ComputationGraph. This will run a backward pass from the last node, calculating the gradients with respect to minimizing the last expression (in our case, we want to minimize the loss). The gradients are stored in the model, and we can now let the trainer take care of the optimization step.
- call trainer.update() to optimize the values with respect to the latest gradients.
In the code below, the first two steps are performed through the loss expression itself: loss.value() runs the forward pass and loss.backward() computes the gradients.
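For the xor network, the gradients that the backward pass accumulates can be derived by hand with the chain rule. Write \(a = Wx + b\), \(h = \tanh(a)\), \(z = Vh\), and \(o = \sigma(z)\), and assume the binary log loss \(\ell = -(y \log o + (1-y)\log(1-o))\). In the result below, \(\odot\) is elementwise multiplication. The last line is a plain SGD step with learning rate \(\eta\), which may differ in details from what SimpleSGDTrainer actually does.

\[
\frac{\partial \ell}{\partial z} = o - y, \qquad
\frac{\partial \ell}{\partial V} = (o - y)\, h^\top, \qquad
\delta \equiv \frac{\partial \ell}{\partial a} = (1 - h \odot h) \odot V^\top (o - y),
\]
\[
\frac{\partial \ell}{\partial W} = \delta\, x^\top, \qquad
\frac{\partial \ell}{\partial b} = \delta, \qquad
\theta \leftarrow \theta - \eta\, \frac{\partial \ell}{\partial \theta} \quad \text{for } \theta \in \{W, V, b\}.
\]

The shapes agree with the definitions above: \((o-y)h^\top\) is \(1 \times 8\) like \(V\), and \(\delta x^\top\) is \(8 \times 2\) like \(W\).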
In [9]:
x.set([1,0])
y.set(1)
loss_value = loss.value() # this performs a forward through the network.
print("the loss before step is:", loss_value)
# now do an optimization step
loss.backward() # compute the gradients
trainer.update()
# see how it affected the loss:
loss_value = loss.value(recalculate=True) # recalculate=True means "don't use precomputed value"
print("the loss after step is:", loss_value)
the loss before step is: 0.335373580456
the loss after step is: 0.296859383583
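The same forward, backward, and update cycle can be written out without DyNet, using chain-rule gradients for \(\sigma(V\tanh(Wx+b))\) with the binary log loss. This is a plain-Python illustration only, not what DyNet runs internally. The random initialization, the seed, and the learning rate of 0.1 are hypothetical choices. A central finite difference checks one hand-derived gradient entry.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(params, x):
    """Return (hidden activations h, output o) for o = sigmoid(V tanh(Wx + b))."""
    W, V, b = params["W"], params["V"], params["b"]
    h = [math.tanh(sum(W[i][j] * x[j] for j in range(2)) + b[i]) for i in range(8)]
    o = sigmoid(sum(V[i] * h[i] for i in range(8)))
    return h, o

def loss_fn(params, x, y):
    _, o = forward(params, x)
    return -(y * math.log(o) + (1 - y) * math.log(1 - o))

def backward(params, x, y):
    """Hand-derived gradients of the binary log loss w.r.t. W, V and b."""
    h, o = forward(params, x)
    dz = o - y                                                # d loss / d (V h)
    dV = [dz * h_i for h_i in h]
    delta = [(1 - h[i] ** 2) * params["V"][i] * dz for i in range(8)]  # d loss / d (Wx + b)
    dW = [[delta[i] * x[j] for j in range(2)] for i in range(8)]
    return {"W": dW, "V": dV, "b": delta}

def sgd_step(params, grads, lr):
    params["W"] = [[w - lr * g for w, g in zip(rw, rg)] for rw, rg in zip(params["W"], grads["W"])]
    params["V"] = [v - lr * g for v, g in zip(params["V"], grads["V"])]
    params["b"] = [v - lr * g for v, g in zip(params["b"], grads["b"])]

random.seed(0)
params = {
    "W": [[random.uniform(-1, 1) for _ in range(2)] for _ in range(8)],  # 8x2
    "V": [random.uniform(-1, 1) for _ in range(8)],                       # the 1x8 row, stored flat
    "b": [random.uniform(-1, 1) for _ in range(8)],
}
x, y = [1, 0], 1

grads = backward(params, x, y)

# Central finite difference on W[0][0] to confirm the hand-derived gradient.
eps = 1e-6
params["W"][0][0] += eps
up = loss_fn(params, x, y)
params["W"][0][0] -= 2 * eps
down = loss_fn(params, x, y)
params["W"][0][0] += eps
numeric = (up - down) / (2 * eps)

before = loss_fn(params, x, y)
sgd_step(params, grads, lr=0.1)
after = loss_fn(params, x, y)
print("the loss before step is: %.6f" % before)
print("the loss after step is: %.6f" % after)
```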
The optimization step indeed made the loss decrease. We now need to run this in a loop. To this end, we will create a training set, and iterate over it.
For the xor problem, the training instances are easy to create.
In [10]:
def create_xor_instances(num_rounds=2000):
    questions = []
    answers = []
    for _ in range(num_rounds):
        for x1 in (0, 1):
            for x2 in (0, 1):
                answer = 0 if x1 == x2 else 1
                questions.append((x1, x2))
                answers.append(answer)
    return questions, answers
questions, answers = create_xor_instances()
We now feed each question / answer pair to the network, and try to minimize the loss.
In [11]:
total_loss = 0
seen_instances = 0
for question, answer in zip(questions, answers):
    x.set(question)
    y.set(answer)
    seen_instances += 1
    total_loss += loss.value()
    loss.backward()
    trainer.update()
    if seen_instances > 1 and seen_instances % 100 == 0:
        print("average loss is:", total_loss / seen_instances)
average loss is: 0.730996069312
average loss is: 0.686455376148
average loss is: 0.614968097508
average loss is: 0.529396591447
average loss is: 0.454356552631
average loss is: 0.39492503399
average loss is: 0.348310606687
average loss is: 0.311234809482
average loss is: 0.281200638587
average loss is: 0.256437818106
average loss is: 0.235696636033
average loss is: 0.218082525641
average loss is: 0.202943060785
average loss is: 0.189793206944
average loss is: 0.178265773896
average loss is: 0.168078109015
average loss is: 0.15900931143
average loss is: 0.150884356805
average loss is: 0.143562835396
average loss is: 0.136930837112
average loss is: 0.130894997159
average loss is: 0.125378077089
average loss is: 0.120315633187
average loss is: 0.115653475622
average loss is: 0.111345707807
average loss is: 0.107353201057
average loss is: 0.103642390902
average loss is: 0.100184321725
average loss is: 0.0969538828368
average loss is: 0.0939291894056
average loss is: 0.0910910811149
average loss is: 0.0884227104994
average loss is: 0.0859092032744
average loss is: 0.0835373785728
average loss is: 0.0812955136038
average loss is: 0.0791731475857
average loss is: 0.0771609158713
average loss is: 0.0752504101568
average loss is: 0.0734340592178
average loss is: 0.0717050271845
average loss is: 0.0700571256665
average loss is: 0.0684847396141
average loss is: 0.0669827620572
average loss is: 0.0655465372522
average loss is: 0.0641718128339
average loss is: 0.0628546962203
average loss is: 0.0615916178524
average loss is: 0.0603792975615
average loss is: 0.0592147165184
average loss is: 0.0580950913344
average loss is: 0.0570178513814
average loss is: 0.0559806190546
average loss is: 0.0549811920022
average loss is: 0.0540175269391
average loss is: 0.0530877257938
average loss is: 0.0521900229302
average loss is: 0.0513227736969
average loss is: 0.0504844442235
average loss is: 0.0496736022536
average loss is: 0.0488889090025
average loss is: 0.0481291114653
average loss is: 0.0473930355647
average loss is: 0.0466795804093
average loss is: 0.0459877123818
average loss is: 0.0453164599289
average loss is: 0.0446649091876
average loss is: 0.0440321997496
average loss is: 0.0434175205679
average loss is: 0.0428201068594
average loss is: 0.042239236579
average loss is: 0.041674227424
average loss is: 0.0411244342562
average loss is: 0.0405892467939
average loss is: 0.0400680867989
average loss is: 0.0395604063634
average loss is: 0.0390656857708
average loss is: 0.0385834318376
average loss is: 0.0381131761705
average loss is: 0.037654473684
average loss is: 0.0372069010154
Our network is now trained. Let’s verify that it indeed learned the xor function:
In [12]:
x.set([0,1])
print("0,1", output.value())
x.set([1,0])
print("1,0", output.value())
x.set([0,0])
print("0,0", output.value())
x.set([1,1])
print("1,1", output.value())
0,1 0.998090803623
1,0 0.998076915741
0,0 0.00135990511626
1,1 0.00213058013469
In case we are curious about the parameter values, we can query them:
In [13]:
W.value()
Out[13]:
array([[ 1.26847982, 1.25287616],
[ 0.91610891, 0.80253637],
[ 3.18741179, -2.58643913],
[-0.82472938, -0.68830448],
[-2.74162889, 3.30151606],
[ 0.2677069 , 0.46926948],
[-2.60197234, -2.61786079],
[ 0.89582258, -0.44721049]])
In [14]:
V.value()
Out[14]:
array([[-2.33788562, -1.54022419, -4.58266163, -0.91096258, -4.88002253,
-0.70912606, -4.09791088, -0.61150461]])
In [15]:
b.value()
Out[15]:
[-1.9798537492752075,
-1.3854612112045288,
1.2350027561187744,
-0.8094932436943054,
1.3227168321609497,
-0.5688062906265259,
0.9074684381484985,
0.21831640601158142]
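As a sanity check of the formula \(\sigma(V\tanh(Wx+b))\), the trained parameter values printed above can be plugged into a plain-Python forward pass (no DyNet needed). Up to the rounding of the printed values, it reproduces the outputs shown after training.

```python
import math

# Parameter values copied from Out[13], Out[14] and Out[15] above (after training).
W = [[ 1.26847982,  1.25287616],
     [ 0.91610891,  0.80253637],
     [ 3.18741179, -2.58643913],
     [-0.82472938, -0.68830448],
     [-2.74162889,  3.30151606],
     [ 0.2677069 ,  0.46926948],
     [-2.60197234, -2.61786079],
     [ 0.89582258, -0.44721049]]
V = [-2.33788562, -1.54022419, -4.58266163, -0.91096258, -4.88002253,
     -0.70912606, -4.09791088, -0.61150461]
b = [-1.9798537492752075, -1.3854612112045288, 1.2350027561187744,
     -0.8094932436943054, 1.3227168321609497, -0.5688062906265259,
     0.9074684381484985, 0.21831640601158142]

def network(x):
    """sigmoid(V tanh(Wx + b)) for a 2-dim input x."""
    hidden = [math.tanh(W[i][0] * x[0] + W[i][1] * x[1] + b[i]) for i in range(8)]
    z = sum(v * h for v, h in zip(V, hidden))
    return 1.0 / (1.0 + math.exp(-z))

for x in ([0, 1], [1, 0], [0, 0], [1, 1]):
    print("%d,%d %.4f" % (x[0], x[1], network(x)))
```

Inputs with different bits give outputs near 1 (about 0.998), and inputs with equal bits give outputs near 0, as in the printout after training.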
To summarize
Here is a complete program:
In [16]:
from dynet import *
# define the parameters
m = Model()
pW = m.add_parameters((8,2))
pV = m.add_parameters((1,8))
pb = m.add_parameters((8))
# renew the computation graph
renew_cg()
# add the parameters to the graph
W = parameter(pW)
V = parameter(pV)
b = parameter(pb)
# create the network
x = vecInput(2) # an input vector of size 2.
output = logistic(V*(tanh((W*x)+b)))
# define the loss with respect to an output y.
y = scalarInput(0) # this will hold the correct answer
loss = binary_log_loss(output, y)
# create training instances
def create_xor_instances(num_rounds=2000):
    questions = []
    answers = []
    for _ in range(num_rounds):
        for x1 in (0, 1):
            for x2 in (0, 1):
                answer = 0 if x1 == x2 else 1
                questions.append((x1, x2))
                answers.append(answer)
    return questions, answers
questions, answers = create_xor_instances()
# train the network
trainer = SimpleSGDTrainer(m)
total_loss = 0
seen_instances = 0
for question, answer in zip(questions, answers):
    x.set(question)
    y.set(answer)
    seen_instances += 1
    total_loss += loss.value()
    loss.backward()
    trainer.update()
    if seen_instances > 1 and seen_instances % 100 == 0:
        print("average loss is:", total_loss / seen_instances)
average loss is: 0.725458401442
average loss is: 0.656036808193
average loss is: 0.563800293456
average loss is: 0.473188629244
average loss is: 0.401578919515
average loss is: 0.347210133697
average loss is: 0.30537398648
average loss is: 0.27243115149
average loss is: 0.245902155418
average loss is: 0.22411154042
average loss is: 0.205906257995
average loss is: 0.190473453378
average loss is: 0.177226172269
average loss is: 0.165731058566
average loss is: 0.155661680364
average loss is: 0.146767699362
average loss is: 0.138854031509
average loss is: 0.131766459678
average loss is: 0.125381493949
average loss is: 0.119599098227
average loss is: 0.114337381247
average loss is: 0.109528665657
average loss is: 0.105116533384
average loss is: 0.101053577985
average loss is: 0.0972996741069
average loss is: 0.093820632044
average loss is: 0.0905871372991
average loss is: 0.0875739114509
average loss is: 0.0847590394488
average loss is: 0.0821234288742
average loss is: 0.079650368163
average loss is: 0.0773251660003
average loss is: 0.0751348558335
average loss is: 0.0730679483965
average loss is: 0.0711142273374
average loss is: 0.0692645774255
average loss is: 0.0675108397355
average loss is: 0.0658456894337
average loss is: 0.0642625315812
average loss is: 0.0627554119665
average loss is: 0.0613189413034
average loss is: 0.059948229676
average loss is: 0.0586388300699
average loss is: 0.05738668844
average loss is: 0.0561881021362
average loss is: 0.0550396820511
average loss is: 0.0539383201534
average loss is: 0.0528811609025
average loss is: 0.0518655761557
average loss is: 0.0508891425877
average loss is: 0.0499496224367
average loss is: 0.0490449456893
average loss is: 0.0481731953563
average loss is: 0.0473325925335
average loss is: 0.0465214848134
average loss is: 0.0457383351514
average loss is: 0.0449817118815
average loss is: 0.0442502796927
average loss is: 0.0435427918518
average loss is: 0.0428580828441
average loss is: 0.0421950617608
average loss is: 0.0415527067172
average loss is: 0.0409300591527
average loss is: 0.0403262192239
average loss is: 0.0397403411381
average loss is: 0.0391716292271
average loss is: 0.0386193343495
average loss is: 0.0380827505725
average loss is: 0.0375612118193
average loss is: 0.0370540894219
average loss is: 0.0365607894682
average loss is: 0.0360807502221
average loss is: 0.0356134402267
average loss is: 0.0351583559568
average loss is: 0.0347150203697
average loss is: 0.0342829808685
average loss is: 0.0338618080745
average loss is: 0.0334510939502
average loss is: 0.0330504509121
average loss is: 0.0326595103741
Dynamic Networks¶
Dynamic networks are very similar to static ones, but instead of creating the network once and then calling “set” in each training example to change the inputs, we just create a new network for each training example.
We present an example below. While the value of this may not be clear in the xor example, the dynamic approach is very convenient for networks whose structure is not fixed, such as recurrent or recursive networks.
In [17]:
from dynet import *
# create training instances, as before
def create_xor_instances(num_rounds=2000):
    questions = []
    answers = []
    for _ in xrange(num_rounds):
        for x1 in 0,1:
            for x2 in 0,1:
                answer = 0 if x1==x2 else 1
                questions.append((x1,x2))
                answers.append(answer)
    return questions, answers
questions, answers = create_xor_instances()
# create a network for the xor problem given input and output
def create_xor_network(pW, pV, pb, inputs, expected_answer):
    renew_cg() # new computation graph
    W = parameter(pW) # add parameters to graph as expressions
    V = parameter(pV)
    b = parameter(pb)
    x = vecInput(len(inputs))
    x.set(inputs)
    y = scalarInput(expected_answer)
    output = logistic(V*(tanh((W*x)+b)))
    loss = binary_log_loss(output, y)
    return loss
m2 = Model()
pW = m2.add_parameters((8,2))
pV = m2.add_parameters((1,8))
pb = m2.add_parameters((8))
trainer = SimpleSGDTrainer(m2)
seen_instances = 0
total_loss = 0
for question, answer in zip(questions, answers):
    loss = create_xor_network(pW, pV, pb, question, answer)
    seen_instances += 1
    total_loss += loss.value()
    loss.backward()
    trainer.update()
    if (seen_instances > 1 and seen_instances % 100 == 0):
        print "average loss is:",total_loss / seen_instances
average loss is: 0.736730417013
average loss is: 0.725369692743
average loss is: 0.715208243926
average loss is: 0.698906037733
average loss is: 0.667973376453
average loss is: 0.620016210104
average loss is: 0.564173455558
average loss is: 0.511108190748
average loss is: 0.464656613212
average loss is: 0.424903827408
average loss is: 0.390944672838
average loss is: 0.361782596097
average loss is: 0.336552875967
average loss is: 0.314552738269
average loss is: 0.295221981726
average loss is: 0.27811523865
average loss is: 0.262876965393
average loss is: 0.249221329002
average loss is: 0.236916671552
average loss is: 0.225773662324
average loss is: 0.215636288271
average loss is: 0.206374970573
average loss is: 0.197881278039
average loss is: 0.190063834667
average loss is: 0.182845127269
average loss is: 0.176158992879
average loss is: 0.16994863152
average loss is: 0.164165015582
average loss is: 0.158765610311
average loss is: 0.153713339384
average loss is: 0.148975738776
average loss is: 0.14452426397
average loss is: 0.140333718062
average loss is: 0.13638177571
average loss is: 0.132648585576
average loss is: 0.129116437846
average loss is: 0.125769484215
average loss is: 0.122593499324
average loss is: 0.119575678358
average loss is: 0.116704463887
average loss is: 0.113969398874
average loss is: 0.111360997359
average loss is: 0.108870635643
average loss is: 0.106490455879
average loss is: 0.104213282756
average loss is: 0.102032551605
average loss is: 0.0999422444205
average loss is: 0.0979368338955
average loss is: 0.0960112348951
average loss is: 0.094160760665
average loss is: 0.0923810851444
average loss is: 0.0906682085468
average loss is: 0.0890184267577
average loss is: 0.0874283051604
average loss is: 0.0858946543594
average loss is: 0.0844145084265
average loss is: 0.0829851059784
average loss is: 0.0816038727351
average loss is: 0.0802684055211
average loss is: 0.0789764590814
average loss is: 0.0777259325812
average loss is: 0.0765148587798
average loss is: 0.0753413928689
average loss is: 0.0742038039022
average loss is: 0.073100465403
average loss is: 0.072029847966
average loss is: 0.0709905121502
average loss is: 0.0699811016467
average loss is: 0.0690003377412
average loss is: 0.0680470136383
average loss is: 0.0671199895066
average loss is: 0.0662181878878
average loss is: 0.0653405894968
average loss is: 0.0644862291951
average loss is: 0.0636541927901
average loss is: 0.0628436133573
average loss is: 0.062053668331
average loss is: 0.0612835769022
average loss is: 0.0605325971122
average loss is: 0.0598000235481
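The benefit of building a fresh graph per example appears once inputs differ in shape. The sketch below is a plain-Python analogy, not DyNet code. A hypothetical `compose` function combines the leaves of a binary tree with one tanh unit per internal node, so each tree's shape decides which operations run, much as `create_xor_network` rebuilds the graph for every example. The scalar weights `w` and `b` are made-up values.

```python
import math


def compose(tree, w, b):
    """Recursively combine a binary tree of numbers with one tanh unit per node.

    A leaf is a number; an internal node is a (left, right) tuple. The
    operations performed depend on the shape of the tree, so every input
    effectively gets its own network.
    """
    if not isinstance(tree, tuple):
        return float(tree)
    left, right = tree
    return math.tanh(w * (compose(left, w, b) + compose(right, w, b)) + b)


def count_units(tree):
    """Number of tanh units in the network built for this tree."""
    if not isinstance(tree, tuple):
        return 0
    return 1 + count_units(tree[0]) + count_units(tree[1])


# Two hypothetical inputs with different shapes: the second needs more units.
small = (1.0, 0.0)
large = ((1.0, 0.0), (0.5, (0.25, 0.0)))
for tree in (small, large):
    print("%d units, output %.4f" % (count_units(tree), compose(tree, 0.5, 0.1)))
```

In DyNet the same idea is expressed by calling `renew_cg()` and building expressions inside a per-example function, as in the xor code above.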
API tutorial¶
Expression building¶
(Note: some of the calls below may use an older version of the API.)
In [ ]:
from dynet import *
## ==== Create a new computation graph
# (the graph is a singleton: there is exactly one at any time.
# renew_cg() clears the current one and starts anew.)
renew_cg()
## ==== Creating Expressions from user input / constants.
x = scalarInput(value)
v = vecInput(dimension)
v.set([1,2,3])
z = matInput(dim1, dim2)
# for example:
z1 = matInput(2, 2)
z1.set([1,2,3,4])
## ==== We can take the value of an expression.
# For complex expressions, this will run forward propagation.
print z.value()
print z.npvalue() # as numpy array
print v.vec_value() # as vector, if vector
print x.scalar_value() # as scalar, if scalar
print x.value() # picks the appropriate one of the above automatically
## ==== Parameters
# Parameters are things we tune during training.
# Usually a matrix or a vector.
# First we create a model and add the parameters to it.
m = Model()
pW = m.add_parameters((8,8)) # an 8x8 matrix
pb = m.add_parameters(8)
# then we create an Expression out of the model's parameters
W = parameter(pW)
b = parameter(pb)
## ===== Lookup parameters
# Similar to parameters, but represent a "lookup table"
# that maps numbers to vectors.
# These are used for embedding matrices.
# for example, this will have VOCAB_SIZE rows, each of DIM dimensions.
lp = m.add_lookup_parameters((VOCAB_SIZE, DIM))
# lookup parameters can be initialized from an existing array, e.g.:
# lp.init_from_array(wv)
e5 = lookup(lp, 5) # create an Expression from row 5.
e5c = lookup(lp, 5, update=False) # as before, but don't update when optimizing.
e5.set(9) # now the e5 expression contains row 9
e5c.set(9) # ditto
## ===== Combine expressions into complex expressions.
# Math
e = e1 + e2
e = e1 * e2 # for vectors/matrices: matrix multiplication (like e1.dot(e2) in numpy)
e = e1 - e2
e = -e1
e = dot_product(e1, e2)
e = cmult(e1, e2) # component-wise multiply (like e1*e2 in numpy)
e = cdiv(e1, e2) # component-wise divide
e = colwise_add(e1, e2) # column-wise addition
# Matrix Shapes
e = reshape(e1, new_dimension)
e = transpose(e1)
# Per-element unary functions.
e = tanh(e1)
e = exp(e1)
e = log(e1)
e = logistic(e1) # Sigmoid(x)
e = rectify(e1) # Relu (= max(x,0))
e = softsign(e1) # x/(1+|x|)
# softmaxes
e = softmax(e1)
e = log_softmax(e1, restrict=[]) # restrict is a set of indices.
# if not empty, only entries in restrict are part
# of softmax computation, others get 0.
e = sum_cols(e1)
# Picking values from vector expressions
e = pick(e1, k) # k is unsigned integer, e1 is vector. return e1[k]
e = e1[k] # same
e = pickrange(e1, k, v) # like python's e1[k:v] for lists. e1 is an Expression, k,v integers.
e = e1[k:v] # same
e = pickneglogsoftmax(e1, k) # k is unsigned integer. equiv to: (pick(-log(softmax(e1)), k))
# Neural net stuff
e = noise(e1, stddev) # add Gaussian noise with standard deviation stddev to each element
e = dropout(e1, p) # apply dropout with probability p
# functions over lists of expressions
e = esum([e1, e2, ...]) # sum
e = average([e1, e2, ...]) # average
e = concatenate_cols([e1, e2, ...]) # e1, e2, ... are column vectors; returns a matrix (similar to np.hstack([e1,e2,...]))
e = concatenate([e1, e2, ...]) # concatenate
e = affine_transform([e0,e1,e2, ...]) # e = e0 + ((e1*e2) + (e3*e4) ...)
## Loss functions
e = squared_distance(e1, e2)
e = l1_distance(e1, e2)
e = huber_distance(e1, e2, c=1.345)
# e1 must be a scalar with a value between 0 and 1 (the prediction)
# e2 (ty) must be a scalar with a value between 0 and 1 (the target)
# e = -(ty * log(e1) + (1 - ty) * log(1 - e1))
e = binary_log_loss(e1, e2)
# e1 is row vector or scalar
# e2 is row vector or scalar
# m is number
# e = max(0, m - (e1 - e2))
e = pairwise_rank_loss(e1, e2, m=1.0)
# Convolutions
# e1 \in R^{d x s} (input)
# e2 \in R^{d x m} (filter)
e = conv1d_narrow(e1, e2) # e = e1 *conv e2
e = conv1d_wide(e1, e2) # e = e1 *conv e2
e = filter1d_narrow(e1, e2) # e = e1 *filter e2
e = kmax_pooling(e1, k) # kmax-pooling operation (Kalchbrenner et al 2014)
e = kmh_ngram(e1, k) #
e = fold_rows(e1, nrows=2) #
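As a reference for the formulas in the comments above, the sketch below recomputes three of the losses on plain Python floats and lists. It is not DyNet's implementation. The `ref_` prefix is a hypothetical naming choice that keeps these helpers from shadowing DyNet's own functions after `from dynet import *`. The negative log-softmax uses the standard log-sum-exp trick for numerical stability.

```python
import math


def ref_binary_log_loss(p, y):
    """-(y*log(p) + (1-y)*log(1-p)); p must lie strictly between 0 and 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def ref_pairwise_rank_loss(s_pos, s_neg, margin=1.0):
    """max(0, margin - (s_pos - s_neg)): zero once s_pos beats s_neg by margin."""
    return max(0.0, margin - (s_pos - s_neg))


def ref_pick_neg_log_softmax(scores, k):
    """-log(softmax(scores)[k]), computed via log-sum-exp to avoid overflow."""
    mx = max(scores)
    log_z = mx + math.log(sum(math.exp(s - mx) for s in scores))
    return log_z - scores[k]


print(ref_binary_log_loss(0.9, 1))       # small loss: confident and correct
print(ref_pairwise_rank_loss(3.0, 1.0))  # margin satisfied, so 0.0
print(ref_pick_neg_log_softmax([1.0, 2.0, 3.0], 2))
```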
Recipe¶
In [6]:
from dynet import *
import numpy as np
# create model
m = Model()
# add parameters to model
pW = m.add_parameters((10,30))
pB = m.add_parameters(10)
lookup = m.add_lookup_parameters((500, 10))
print "added"
# create trainer
trainer = SimpleSGDTrainer(m)
# Weight decay (similar to L2 regularization) is set via the --dynet-weight-decay command-line flag.
# Learning rate parameters can be passed to the trainer:
# alpha = 0.1 # learning rate
# trainer = SimpleSGDTrainer(m, e0=alpha)
# function for graph creation
def create_network_return_loss(inputs, expected_output):
    """
    inputs is a list of numbers (row indices into the lookup table)
    """
    renew_cg()
    W = parameter(pW) # from parameters to expressions
    b = parameter(pB)
    emb_vectors = [lookup[i] for i in inputs]
    net_input = concatenate(emb_vectors)
    net_output = softmax( (W*net_input) + b)
    loss = -log(pick(net_output, expected_output))
    return loss
# function for prediction
def create_network_return_best(inputs):
    """
    inputs is a list of numbers (row indices into the lookup table)
    """
    renew_cg()
    W = parameter(pW)
    b = parameter(pB)
    emb_vectors = [lookup[i] for i in inputs]
    net_input = concatenate(emb_vectors)
    net_output = softmax( (W*net_input) + b)
    return np.argmax(net_output.npvalue())
# train network
for epoch in xrange(5):
    for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
        print inp, lbl
        loss = create_network_return_loss(inp, lbl)
        print loss.value() # need to run loss.value() for the forward prop
        loss.backward()
        trainer.update()
print create_network_return_best([1,2,3])
added
[1, 2, 3] 1
2.71492385864
[3, 2, 4] 2
2.48228144646
[1, 2, 3] 1
2.00279903412
[3, 2, 4] 2
1.82602763176
[1, 2, 3] 1
1.44809651375
[3, 2, 4] 2
1.34181213379
[1, 2, 3] 1
1.03570735455
[3, 2, 4] 2
0.988352060318
[1, 2, 3] 1
0.744616270065
[3, 2, 4] 2
0.732948303223
1
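To make explicit what `create_network_return_loss` computes, here is the same forward pass in plain Python, with tiny made-up numbers in place of DyNet expressions. It looks up one embedding row per input id, concatenates the rows, applies `W*x + b`, takes a softmax, and returns `-log` of the probability of the expected class. The dimensions (embedding size 2, three inputs, two classes) and all `toy_` parameter values are assumptions chosen for illustration. The recipe above uses 10-dimensional embeddings and a 10x30 weight matrix.

```python
import math


def py_softmax(scores):
    """Numerically stable softmax over a list of floats."""
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def forward_loss(lookup_table, W, b, inputs, expected_output):
    """Plain-Python analogue of create_network_return_loss."""
    # Look up one embedding row per input id and concatenate them.
    net_input = [value for i in inputs for value in lookup_table[i]]
    # Affine layer: W * net_input + b.
    scores = [sum(w_ij * x_j for w_ij, x_j in zip(row, net_input)) + b_i
              for row, b_i in zip(W, b)]
    probs = py_softmax(scores)
    # Negative log-probability of the expected class, like -log(pick(...)).
    return -math.log(probs[expected_output]), probs


# Hypothetical parameters: 5 embeddings of size 2, so 3 inputs give 6 features.
toy_lookup = [[0.1 * i, -0.1 * i] for i in range(5)]
toy_W = [[0.0] * 6,
         [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]]   # 2 classes x 6 features
toy_b = [0.0, 0.0]
loss, probs = forward_loss(toy_lookup, toy_W, toy_b, [1, 2, 3], 1)
print("loss %.4f, probs %s" % (loss, probs))
```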
Recipe (using classes)¶
In [4]:
from dynet import *
import numpy as np
# create model
m = Model()
# create a class encapsulating the network
class OurNetwork(object):
    # The init method adds parameters to the model.
    def __init__(self, model):
        self.pW = model.add_parameters((10,30))
        self.pB = model.add_parameters(10)
        self.lookup = model.add_lookup_parameters((500,10))
    # the __call__ method applies the network to an input
    def __call__(self, inputs):
        W = parameter(self.pW)
        b = parameter(self.pB)
        lookup = self.lookup
        emb_vectors = [lookup[i] for i in inputs]
        net_input = concatenate(emb_vectors)
        net_output = softmax( (W*net_input) + b)
        return net_output
    def create_network_return_loss(self, inputs, expected_output):
        renew_cg()
        out = self(inputs)
        loss = -log(pick(out, expected_output))
        return loss
    def create_network_return_best(self, inputs):
        renew_cg()
        out = self(inputs)
        return np.argmax(out.npvalue())
# create network
network = OurNetwork(m)
# create trainer
trainer = SimpleSGDTrainer(m)
# train network
for epoch in xrange(5):
    for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
        print inp, lbl
        loss = network.create_network_return_loss(inp, lbl)
        print loss.value() # need to run loss.value() for the forward prop
        loss.backward()
        trainer.update()
print
print network.create_network_return_best([1,2,3])
[1, 2, 3] 1
2.5900914669
[3, 2, 4] 2
2.00347089767
[1, 2, 3] 1
1.98409461975
[3, 2, 4] 2
1.50869822502
[1, 2, 3] 1
1.50195622444
[3, 2, 4] 2
1.12316584587
[1, 2, 3] 1
1.12293696404
[3, 2, 4] 2
0.831095397472
[1, 2, 3] 1
0.833912611008
[3, 2, 4] 2
0.61754822731
1
Or, alternatively, keep the training outside of the network class¶
In [ ]:
# create network
network = OurNetwork(m)
# create trainer
trainer = SimpleSGDTrainer(m)
# train network
for epoch in xrange(5):
    for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
        print inp, lbl
        renew_cg()
        out = network(inp)
        loss = -log(pick(out, lbl))
        print loss.value() # need to run loss.value() for the forward prop
        loss.backward()
        trainer.update()
print
renew_cg() # start a fresh graph for prediction
print np.argmax(network([1,2,3]).npvalue())
[1, 2, 3] 1
3.63615298271
[3, 2, 4] 2
3.29473733902
[1, 2, 3] 1
2.81605744362
[3, 2, 4] 2
2.46070289612
[1, 2, 3] 1
2.13946056366
[3, 2, 4] 2
1.77259361744
[1, 2, 3] 1
1.57904195786
[3, 2, 4] 2
1.2269589901
[1, 2, 3] 1
1.13014268875
[3, 2, 4] 2
0.830479979515
1
RNNs tutorial¶
In [1]:
# we assume that the dynet module is in your path.
# OUTDATED: we also assume that LD_LIBRARY_PATH includes a pointer to where libcnn_shared.so is.
from dynet import *
An LSTM/RNN overview:¶
A (1-layer) RNN can be thought of as a sequence of cells, \(h_1,\ldots,h_k\), where the index \(i\) of \(h_i\) indicates the time step.
Each cell \(h_i\) has an input \(x_i\) and an output \(r_i\). In addition to \(x_i\), cell \(h_i\) also receives \(r_{i-1}\) as input.
In a deep (multi-layer) RNN, we don’t have a sequence but a grid. That is, we have several layers of sequences:
- \(h_1^3,\ldots,h_k^3\)
- \(h_1^2,\ldots,h_k^2\)
- \(h_1^1,\ldots,h_k^1\)
Let \(r_i^j\) be the output of cell \(h_i^j\). Then:
The input to \(h_i^1\) is \(x_i\) and \(r_{i-1}^1\).
The input to \(h_i^2\) is \(r_i^1\) and \(r_{i-1}^2\), and so on.
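This wiring can be written out directly. The plain-Python sketch below is only an illustration. It uses a scalar tanh cell with made-up weights `w_in` and `w_rec`, whereas real RNN layers use weight matrices, bias vectors, and vector-valued states. It fills in the grid so that \(r_i^1\) depends on \(x_i\) and \(r_{i-1}^1\), and each higher layer \(j\) reads \(r_i^{j-1}\) and \(r_{i-1}^j\). It assumes \(r_0^j = 0\) for every layer.

```python
import math


def deep_rnn(xs, num_layers, w_in=0.5, w_rec=0.3):
    """Return grid[j][i] = r_(i+1)^(j+1) for a stacked scalar tanh RNN."""
    grid = []
    layer_inputs = list(xs)              # layer 1 reads x_1..x_k
    for _ in range(num_layers):
        outputs = []
        prev = 0.0                       # assumed r_0 = 0 for every layer
        for inp in layer_inputs:
            prev = math.tanh(w_in * inp + w_rec * prev)
            outputs.append(prev)
        grid.append(outputs)
        layer_inputs = outputs           # the next layer reads these outputs
    return grid


grid = deep_rnn([1.0, 0.0, -1.0], num_layers=3)
for j, row in enumerate(grid, 1):
    print("layer %d: %s" % (j, ["%.3f" % r for r in row]))
```

The last row of `grid` corresponds to what the builder reports as the output of each column. The earlier rows are the lower layers.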
The LSTM (RNN) Interface¶
RNN / LSTM / GRU follow the same interface. We have a “builder” which is in charge of creating the parameters for the sequence.
In [2]:
model = Model()
NUM_LAYERS=2
INPUT_DIM=50
HIDDEN_DIM=10
builder = LSTMBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, model)
# or:
# builder = SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, model)
Note that when we create the builder, it adds the internal RNN parameters to the model. We do not need to care about them, but they will be optimized together with the rest of the network’s parameters.
In [3]:
s0 = builder.initial_state()
In [4]:
x1 = vecInput(INPUT_DIM)
In [5]:
s1=s0.add_input(x1)
y1 = s1.output()
# here, we add x1 to the RNN, and the output we get from the top is y1 (a HIDDEN_DIM-dim vector)
In [6]:
y1.npvalue().shape
Out[6]:
(10,)
In [7]:
s2=s1.add_input(x1) # we can add another input
y2=s2.output()
If our LSTM/RNN were one layer deep, y2 would be equal to the hidden state. However, since it is 2 layers deep, y2 is only the hidden state (= output) of the last layer.
If we want access to all the hidden states (the outputs of both the first and the last layers), we can use the .h() method, which returns a tuple of expressions, one for each layer:
In [8]:
print s2.h()
(exprssion 54/0, exprssion 66/0)
The same interface we have seen so far for the LSTM also holds for the simple RNN:
In [9]:
# create a simple rnn builder
rnnbuilder=SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, model)
# start a new sequence
rs0 = rnnbuilder.initial_state()
# add inputs
rs1 = rs0.add_input(x1)
ry1 = rs1.output()
print "all layers:", s1.h()
all layers: (exprssion 32/0, exprssion 42/0)
In [10]:
print s1.s()
(exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)
To summarize, when calling .add_input(x) on an RNNState, the state creates a new RNN/LSTM column, passing it:
1. the state of the current RNN column
2. the input x
The resulting state is then returned, and we can call its output() method to get the output y, which is the output at the top of the column. We can access the outputs of all the layers (not only the last one) using the .h() method of the state.
.s(): The internal state of the RNN may be more involved than just the outputs \(h\). This is the case for the LSTM, which keeps an extra “memory” cell that is used when calculating \(h\) and is also passed to the next column. To access the entire hidden state, we use the .s() method.
The output of .s() differs by the type of RNN being used. For the simple RNN, it is the same as .h(). For the LSTM, it is more involved.
In [11]:
rnn_h = rs1.h()
rnn_s = rs1.s()
print "RNN h:", rnn_h
print "RNN s:", rnn_s
lstm_h = s1.h()
lstm_s = s1.s()
print "LSTM h:", lstm_h
print "LSTM s:", lstm_s
RNN h: (exprssion 74/0, exprssion 76/0)
RNN s: (exprssion 74/0, exprssion 76/0)
LSTM h: (exprssion 32/0, exprssion 42/0)
LSTM s: (exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)
As we can see, the LSTM has two extra state expressions (one for each hidden layer) before the outputs h.
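The extra expressions come from the memory cell \(c\) that each LSTM layer keeps alongside its output \(h\). The scalar sketch below follows the textbook LSTM update (input, forget and output gates plus a tanh candidate, no peephole connections). The per-gate weights in `WEIGHTS` are made up. DyNet's LSTMBuilder may differ in details such as peepholes or gate coupling, so treat this only as an illustration of why .s() returns more than .h().

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical scalar weights (w_x, w_h) per gate; a real LSTM layer uses
# weight matrices, bias vectors, and vector-valued c and h.
WEIGHTS = {"i": (0.5, 0.4), "f": (0.6, 0.3), "o": (0.7, 0.2), "g": (0.8, 0.1)}


def lstm_step(x, h_prev, c_prev, weights=WEIGHTS):
    """One step of a textbook scalar LSTM cell; returns the new (c, h)."""
    def pre(gate):
        w_x, w_h = weights[gate]
        return w_x * x + w_h * h_prev
    i = sigmoid(pre("i"))          # input gate
    f = sigmoid(pre("f"))          # forget gate
    o = sigmoid(pre("o"))          # output gate
    g = math.tanh(pre("g"))        # candidate memory value
    c = f * c_prev + i * g         # memory cell, passed on to the next column
    h = o * math.tanh(c)           # output, the value .h() and .output() expose
    return c, h


c, h = 0.0, 0.0
for x in [1.0, 0.5]:
    c, h = lstm_step(x, h, c)
    # The full state of this one-layer cell is (c, h); its output is only h.
    print("state (c, h) = (%.4f, %.4f)" % (c, h))
```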
Extra options in the RNN/LSTM interface¶
Stack LSTM: The RNNs are shaped as a stack: we can remove the top and continue from the previous state. This is done either by remembering the previous state and continuing it with a new .add_input(), or by accessing the previous state of a given state through its .prev() method.
Initializing a new sequence with a given state: When we call builder.initial_state(), we are assuming the state has random or zero initialization. If we want, we can specify a list of expressions that will serve as the initial state. The expected format is the same as the result of a call to .final_s(). TODO: this is not supported yet.
In [12]:
s2=s1.add_input(x1)
s3=s2.add_input(x1)
s4=s3.add_input(x1)
# let's continue s3 with a new input.
s5=s3.add_input(x1)
# we now have two different sequences:
# s0,s1,s2,s3,s4
# s0,s1,s2,s3,s5
# the two sequences share parameters.
assert(s5.prev() == s3)
assert(s4.prev() == s3)
s6=s3.prev().add_input(x1)
# we now have an additional sequence:
# s0,s1,s2,s6
In [13]:
s6.h()
Out[13]:
(exprssion 184/0, exprssion 196/0)
In [14]:
s6.s()
Out[14]:
(exprssion 180/0, exprssion 192/0, exprssion 184/0, exprssion 196/0)
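The stack behaviour follows from add_input returning a new state that remembers its predecessor, rather than modifying the existing state. The hypothetical `ToyState` class below imitates only that part of the interface (`add_input`, `prev`, `output`) in plain Python. A running sum stands in for the RNN computation, so this is not DyNet's RNNState. It reproduces the branching of the cell above: s4 and s5 both extend s3, and s6 extends s3.prev().

```python
class ToyState(object):
    """Immutable stand-in for an RNN state; the 'RNN' is just a running sum."""

    def __init__(self, value=0.0, prev=None):
        self._value = value
        self._prev = prev

    def add_input(self, x):
        # Never mutates self: the new state points back at this one,
        # so several continuations can share the same prefix.
        return ToyState(self._value + x, self)

    def prev(self):
        return self._prev

    def output(self):
        return self._value


s0 = ToyState()
s1 = s0.add_input(1)
s2 = s1.add_input(1)
s3 = s2.add_input(1)
s4 = s3.add_input(1)            # s0,s1,s2,s3,s4
s5 = s3.add_input(10)           # s0,s1,s2,s3,s5 (shares s0..s3 with s4)
s6 = s3.prev().add_input(100)   # s0,s1,s2,s6
print("%s %s %s" % (s4.output(), s5.output(), s6.output()))  # 4.0 13.0 102.0
```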
Aside: memory efficient transduction¶
The RNNState interface is convenient, and allows for incremental input construction. However, sometimes we know the sequence of inputs in advance, and care only about the sequence of output expressions. In this case, we can use the add_inputs(xs) method, where xs is a list of Expressions.
In [15]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
states = state.add_inputs(xs)
outputs = [s.output() for s in states]
hs = [s.h() for s in states]
print outputs, hs
[exprssion 200/0, exprssion 206/0, exprssion 212/0] [(exprssion 198/0, exprssion 200/0), (exprssion 203/0, exprssion 206/0), (exprssion 209/0, exprssion 212/0)]
This is convenient.
What if we do not care about .s() and .h(), and do not need to access the previous vectors? In such cases we can use the transduce(xs) method instead of add_inputs(xs). transduce takes in a sequence of Expressions and returns a sequence of Expressions. Because it does not return RNNStates, transduce is much more memory efficient than add_inputs or a series of calls to add_input.
In [16]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
outputs = state.transduce(xs)
print outputs
[exprssion 216/0, exprssion 222/0, exprssion 228/0]
Character-level LSTM¶
Now that we know the basics of RNNs, let’s build a character-level LSTM language model. We have a sequence LSTM that, at each step, gets a character as input and needs to predict the next character.
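Concretely, each sentence is wrapped in `<EOS>` symbols and turned into (input, target) pairs of consecutive characters. This is what the loop over `zip(sentence, sentence[1:])` in `do_one_sentence` does in the code that follows. Here is a plain-Python sketch of just that preprocessing step. The helper name `training_pairs` and the short example string are hypothetical.

```python
def training_pairs(sentence, eos="<EOS>"):
    """Return (current, next) symbol pairs for a character-level LM."""
    symbols = [eos] + list(sentence) + [eos]
    return list(zip(symbols, symbols[1:]))


for current, nxt in training_pairs("ab c"):
    print("%s -> %s" % (current, nxt))
```

A sentence of n characters therefore yields n + 1 prediction steps, the last of which predicts `<EOS>`.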
In [17]:
import random
from collections import defaultdict
from itertools import count
import sys
LAYERS = 2
INPUT_DIM = 50
HIDDEN_DIM = 50
characters = list("abcdefghijklmnopqrstuvwxyz ")
characters.append("<EOS>")
int2char = list(characters)
char2int = {c:i for i,c in enumerate(characters)}
VOCAB_SIZE = len(characters)
In [18]:
model = Model()
srnn = SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, model)
lstm = LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, model)
params = {}
params["lookup"] = model.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM))
params["R"] = model.add_parameters((VOCAB_SIZE, HIDDEN_DIM))
params["bias"] = model.add_parameters((VOCAB_SIZE))
# compute the loss of the RNN for one sentence
def do_one_sentence(rnn, sentence):
    # setup the sentence
    renew_cg()
    s0 = rnn.initial_state()
    R = parameter(params["R"])
    bias = parameter(params["bias"])
    lookup = params["lookup"]
    sentence = ["<EOS>"] + list(sentence) + ["<EOS>"]
    sentence = [char2int[c] for c in sentence]
    s = s0
    loss = []
    for char,next_char in zip(sentence,sentence[1:]):
        s = s.add_input(lookup[char])
        probs = softmax(R*s.output() + bias)
        loss.append( -log(pick(probs,next_char)) )
    loss = esum(loss)
    return loss
# generate from model:
def generate(rnn):
    def sample(probs):
        rnd = random.random()
        for i,p in enumerate(probs):
            rnd -= p
            if rnd <= 0: break
        return i
    # setup the sentence
    renew_cg()
    s0 = rnn.initial_state()
    R = parameter(params["R"])
    bias = parameter(params["bias"])
    lookup = params["lookup"]
    s = s0.add_input(lookup[char2int["<EOS>"]])
    out=[]
    while True:
        probs = softmax(R*s.output() + bias)
        probs = probs.vec_value()
        next_char = sample(probs)
        out.append(int2char[next_char])
        if out[-1] == "<EOS>": break
        s = s.add_input(lookup[next_char])
    return "".join(out[:-1]) # strip the <EOS>
# train, and generate a sample every 5 iterations
def train(rnn, sentence):
    trainer = SimpleSGDTrainer(model)
    for i in xrange(200):
        loss = do_one_sentence(rnn, sentence)
        loss_value = loss.value()
        loss.backward()
        trainer.update()
        if i % 5 == 0:
            print loss_value,
            print generate(rnn)
Notice that:
1. We pass the same rnn-builder to do_one_sentence over and over again. We must re-use the same rnn-builder, as this is where the shared parameters are kept.
2. We call renew_cg() before each sentence, because we want to have a new graph (new network) for this sentence. The parameters will be shared through the model and the shared rnn-builder.
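The `sample` helper inside `generate` draws a character index by inverse-CDF sampling. It subtracts each probability from a uniform random number until the remainder drops to zero or below. The variant below takes the uniform draw `u` as an optional argument so that its behaviour can be checked deterministically; that argument is an addition made for this sketch. Like the original, it falls back to the last index if floating-point rounding leaves a tiny positive remainder.

```python
import random


def sample(probs, u=None):
    """Inverse-CDF sampling: return index i with probability probs[i]."""
    if u is None:
        u = random.random()
    for i, p in enumerate(probs):
        u -= p
        if u <= 0:
            return i
    return len(probs) - 1   # rounding left a small positive remainder


probs = [0.1, 0.6, 0.3]
print(sample(probs, u=0.05))   # 0
print(sample(probs, u=0.5))    # 1
print(sample(probs, u=0.95))   # 2
```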
In [19]:
sentence = "a quick brown fox jumped over the lazy dog"
train(srnn, sentence)
142.737915039 lvawhaevbxulc yxg esuh vkyb gymj dzcnwgq dcjzzk
84.1147460938 woifoa odp jpt gxjofkaattj
44.212223053 a q io uoopr ouxducmwi jfxa j
23.4485988617 p tctflr
9.73490333557 w
3.23773050308 yaqzteu pux oa rntd bxumu yyvvfalejuyhed over the lazy dog
1.06309330463 a quick browe fow jumped over the lazy dog
0.671298980713 a quick broyn ox jumped over the lazy dog
0.490513861179 a quick brown fox jumped over the lazy dog
0.386095941067 a quick brown fox jumped over the lazy dog
0.318082690239 a quick brown fox jumped over the lazy dog
0.270276993513 a quick brown fox jumped over the lazy dog
0.234851941466 a quick brown foz jumped over the lazy dog
0.207555636764 a quick brown fox jumped over the lazy dog
0.185884565115 a quick brown fox jumped over the lazy dog
0.168265148997 a quiuk brown fox jumped over jhe lazy dog
0.153665527701 a quick brown fox jumped over the lazy dog
0.141367897391 a quick brown fox jumped over the lazy dog
0.130873680115 a quick brown fox jumped over the lazy dog
0.121810980141 a quick brown fox jumped over the lazy dog
0.113908931613 a quick brown fox jumped over the lazy dog
0.106958284974 a quick brown fox jumped over the lazy dog
0.100796818733 a quick brown fox jumped over the lazy dog
0.0953008085489 a quick brown fox jumped over the lazy dog
0.090367347002 a zuick brown for jumped over the lazy dog
0.0859087407589 a quick brown fox jumped over the lazy dog
0.0818664133549 a quick brown fox jumped over the lazy dog
0.0781841799617 a quick brown fox jumped over the lazy dog
0.0748091414571 a quick brown fox jumped over the lazy dog
0.0717144161463 a quick brown fox jumped over the lazy dog
0.0688648074865 a quick brown fox jumped over the lazy dog
0.0662328600883 a quick brown fox jumped over the lazy dog
0.0637853741646 a quick brown fox jumped over the lazy dog
0.0615109689534 a quick brown fox jumped over the lazy dog
0.0593910999596 a quick brown fox jumped over the lazy dog
0.0574130378664 a quick brown fox jumped over the lazy dog
0.0555621087551 a quick brown fox jumped over the lazy dog
0.0538215488195 a quick brown fox jumped over the lazy dog
0.0521896965802 a quick brown fox jumped over the lazy dog
0.0506477579474 a quick brown fox jumped over the lazy dog
In [20]:
sentence = "a quick brown fox jumped over the lazy dog"
train(lstm, sentence)
141.891098022 aoyekppy mocalmz xk atc jlg oaddk
128.925964355 hempeyud ki
121.445785522 qpveti fyobec ztmr eioknnueh ehecdvabxmc ydpmdm
110.670722961 z buws lmy vvrw
93.5055999756 vueoa cprlnkrd o ocazk nb olegiep o fftr t
82.1586227417 zj rvsr oej c toz bnarreow fffj
67.430847168 rzfik qoyc ohe hqe oea uitet ou udjkpme oak kdk oe fbu kcz fox dfoprl too o rxat luurnfowrrtj rbtram to url xlj okrr ooe otm hcy roab llsg doy ifzw rrbow rbowwb oke jxpee
54.9477920532 ba uiy doge she ueeze oejv
43.3301696777 qquc crgibbroej oxne ove rr
34.4687461853 uqckk owrbfo og uouk doge l
25.5408306122 reuk lfr own fox juamd ov
18.9417610168 qojn doo broww boan jover txe zacy moen crlw numk fox joge overwa trez quqk browx ox ruor oro fow j uoez kon fror bowe luccmd ogwr foy jodmoed ox
13.1646575928 qucy dov
9.46595668793 wiuuik brttxl laed over tre lazy dog
5.6522898674 rukc irown fox juaped over the lazy dov
3.38144731522 a quick brown fox jumver the lazy dog
1.80010521412 a bfoin fox jumped ovk fox luick brown fox jumped over the lazy dog
1.30616080761 a quic brownn fox jumped over the lazy dog
1.02201879025 a quick brown fox jumped over the lazy dog
0.83735615015 qucck brown fox jcmped over the lazy dog
0.708056390285 a quickz brown fox jumped over the lazy dog
0.612650871277 a quick brown fox jumped over the lazy dog
0.539469838142 a quick brown fox jumped over thel lazy dog
0.481610894203 va quick brown fox jumped over the lazy dog
0.434762001038 a quuck dovtbown fox jumped over the lazy dog
0.396079242229 a quick brown fox jumped over the lazy dog
0.363606244326 a quick brown fox jumped over the laza dog
0.335973978043 a quick brown fox jumped over the lazy dog
0.312186658382 a quick brown fox jumped over the lazy dog
0.291498303413 a quick brown fox qu
0.273335546255 a quick brown fox jumped ove
0.257278442383 a quick brown fox jumped over the lazy dog
0.242971763015 a quick brown fox jumped over the lazy dog
0.230153128505 a quick brown fox jumped over the lazy dog
0.218599274755 a quick brown fox jumped over the lazy dog
0.208135351539 a quick brown fox jumped over the lazy dog
0.198613137007 a quick brown fox jumped over tie lazy dog
0.189909905195 a quick brown fox jumped over the lazy dog
0.181928783655 a quick brown fox jumped over the lazy dog
0.174587100744 a quick brown fox jumped over the lazy dog
The model seems to learn the sentence quite well.
Somewhat surprisingly, the Simple-RNN model learns faster than the LSTM!
How can that be?
The answer is that we are cheating a bit. The sentence we are trying to learn contains each letter bigram exactly once. This means a simple trigram model can memorize it very well.
Try it out with more complex sequences.
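The bigram claim is easy to check. The hypothetical helper below counts character bigrams in the raw sentences and ignores the `<EOS>` markers that the training code adds. No bigram repeats in the fox sentence, so the two previous characters always determine the next one. The sentence used in the next cell repeats several bigrams, such as "th" and "e ".

```python
from collections import Counter


def repeated_bigrams(text):
    """Return the character bigrams that occur more than once in text."""
    counts = Counter(a + b for a, b in zip(text, text[1:]))
    return sorted(bg for bg, n in counts.items() if n > 1)


print(repeated_bigrams("a quick brown fox jumped over the lazy dog"))  # []
print(repeated_bigrams("these pretzels are making me thirsty"))        # [' m', 'e ', 're', 'th']
```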
In [21]:
train(srnn, "these pretzels are making me thirsty")
332.651580811 a quick brown fox jumped over the lazy dog
133.209350586 a quick brown fox jumped over the lazy doe hu yum xd the
65.0720596313 azquick brown fox jumped over ohe iog
31.5592880249 a quick brown fox jumpedrovtretpede pretzelz are makink ma tui idmilt
13.2322559357 theve prwtumpede mhxtjaypny mreticv
1.87829053402 thele pretzelb mre laki loet dre za tuiri mtoina ma qui irwt ere sa taetsdaca qamtuioe ma ick mrolnn mhetsirstyyza qa luijuoethetsepsaaya quirk brmtze ehersjlyaa aumu orkrbtoeqz lrea quijk jrowza quiquihi sakiny mr tui ss thels theqetursy famtzi maethehe iretza lamqzd zretsels area qhirk browna yhetza quirkt rxkwn mox ja isi mq thirsty
0.680327475071 these pretzels are makind me thirsty
0.176128521562 these pretzels are making me thirsty
0.126334354281 these pretzels are making me thirsty
0.10075186193 these pretzels are making me thirsty
0.0846510156989 these pretzels are making me thirsty
0.0734022557735 these pretzels are making me thirsty
0.0650328546762 these pretzels are making me thirsty
0.0585154108703 these pretzels are making me thirsty
0.0532807298005 these pretzels are making me thirsty
0.0489665567875 these pretzels are making me thirsty
0.0453444086015 these pretzels are making me thirsty
0.0422535128891 these pretzels are making me thirsty
0.0395833179355 these pretzels are making me thirsty
0.0372485220432 these mretzels are making me thirsty
0.0351839251816 these pretzels are making me thirsty
0.0333509668708 these pretzels are making me thirsty
0.0317104011774 these pretzels are making me thirsty
0.0302277039737 these pretzels are making me thirsty
0.0288887582719 these pretzels are making me thirsty
0.0276643745601 these pretzels are making me thirsty
0.0265435613692 these pretzels are making me thirsty
0.0255212895572 these pretzels are making me thirsty
0.0245705824345 these pretzels are making me thirsty
0.0236932244152 these pretzels are making me thirsty
0.0228785891086 these pretzels are making me thirsty
0.0221205893904 these pretzels are making me thirsty
0.0214090794325 these pretzels are making me thirsty
0.0207556784153 these pretzels are making me thirsty
0.0201329570264 these pretzels are making me thirsty
0.0195484217256 these pretzels are making me thirsty
0.0190003421158 these pretzels are making me thirsty
0.0184785164893 these pretzels are making me thirsty
0.0179911740124 these pretzels are making me thirsty
0.0175334792584 these pretzels are making me thirsty
Saving Models¶
To save model parameters, the user tells the model, at save time, which components to save. The same components must then be specified, in the same order, at load time. Note, however, that there is no need to specify the sizes etc., as this is handled by the save/load mechanism:
# saving:
from dynet import *
m = Model()
W = m.add_parameters((100,100))
lb = LSTMBuilder(1, 100, 100, m) # this also adds parameters to the model
b = m.add_parameters((30))
m.save("filename", [W,b,lb])
# loading
m = Model()
(W, b, lb) = m.load("filename")
Each item passed in the list must satisfy at least one of the following:
- be of type Parameters or LookupParameters (the return types of add_parameters or add_lookup_parameters);
- be a built-in “complex” builder, such as LSTMBuilder or GRUBuilder, that adds parameters to the model;
- be a user-defined class that extends the new dynet.Saveable class and implements the required interface.
The Saveable class is used for easy creation of user-defined “sub networks” that can be saved and loaded as part of the model saving mechanism.
class OneLayerMLP(Saveable):
    def __init__(self, model, num_input, num_hidden, num_out, act=tanh):
        self.W1 = model.add_parameters((num_hidden, num_input))
        self.W2 = model.add_parameters((num_out, num_hidden))
        self.b1 = model.add_parameters((num_hidden))
        self.b2 = model.add_parameters((num_out))
        self.act = act
        self.shape = (num_input, num_out)
    def __call__(self, input_exp):
        W1 = parameter(self.W1)
        W2 = parameter(self.W2)
        b1 = parameter(self.b1)
        b2 = parameter(self.b2)
        g = self.act
        return softmax(W2*g(W1*input_exp + b1)+b2)
    # the Saveable interface requires the implementation
    # of the two following methods, specifying all the
    # Parameters / LookupParameters / LSTMBuilder / Saveables / etc
    # that are directly created by this Saveable.
    def get_components(self):
        return (self.W1, self.W2, self.b1, self.b2)
    def restore_components(self, components):
        self.W1, self.W2, self.b1, self.b2 = components
It can then be used as follows:
import numpy
from dynet import *

m = Model()
# create an embedding table.
E = m.add_lookup_parameters((1000,10))
# create an MLP from 10 to 4 with a hidden layer of 20.
mlp = OneLayerMLP(m, 10, 20, 4, rectify)
# use them together.
output = mlp(E[3])
# now save the model:
m.save("filename",[mlp, E])
# now load into a fresh model:
m2 = Model()
mlp2, E2 = m2.load("filename")
output2 = mlp2(E2[3])
assert numpy.array_equal(output2.npvalue(), output.npvalue())
A more comprehensive tutorial can be found here (EMNLP 2016 tutorial).
Command Line Options¶
All programs using DyNet have a few command line options. These must be specified at the very beginning of the command line, before other options.
--dynet-mem NUMBER
: DyNet runs by default with 512MB of memory, which is split evenly between the forward and backward steps and parameter storage. This will be expanded automatically every time one of the pools runs out of memory. By setting NUMBER here, DyNet will allocate more memory immediately at the initialization stage. Note that you can also individually set the amount of memory for forward calculation, backward calculation, and parameters by using comma-separated values: --dynet-mem FOR,BACK,PARAM. This is useful if, for example, you are performing testing and don’t need to allocate any memory for backward calculation.
--dynet-weight-decay NUMBER
: Adds weight decay to the parameters, which modifies each parameter w such that w *= (1-weight_decay) after every update. This is similar to L2 regularization, but different in a couple of ways, which are noted in detail in the “Unorthodox Design” section.
--dynet-gpus NUMBER
: Specify how many GPUs you want to use, if DyNet is compiled with CUDA. Currently, only one GPU is supported.
--dynet-gpu-ids X,Y,Z
: Specify the GPUs that you want to use by device ID. Currently only one GPU is supported, but if you use this command you can select which one to use.
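The weight-decay rule for --dynet-weight-decay can be illustrated with a few lines of plain Python. This is a sketch of the arithmetic only, not DyNet's implementation, and the helpers apply_weight_decay and decayed_weights are hypothetical. Because every update multiplies each weight by (1 - weight_decay), a weight that receives no gradient shrinks geometrically: after k updates it equals w0 * (1 - weight_decay) ** k.

```python
def apply_weight_decay(weights, weight_decay):
    """Hypothetical helper: apply w *= (1 - weight_decay) to every weight."""
    factor = 1.0 - weight_decay
    return [w * factor for w in weights]


def decayed_weights(weights, weight_decay, num_updates):
    """Apply the decay rule num_updates times, ignoring any gradient steps."""
    for _ in range(num_updates):
        weights = apply_weight_decay(weights, weight_decay)
    return weights


w0 = [1.0, -2.0, 0.5]
w = decayed_weights(w0, weight_decay=0.1, num_updates=3)
# Each weight has now been multiplied by (1 - 0.1) ** 3 = 0.729.
print(w)
```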
Debugging¶
There are a number of tools to make debugging easier in DyNet.
Visualization¶
It is possible to create visualizations of the computation graph by calling the print_graphviz() function, which can be helpful for debugging. When this functionality is used in Python, it is necessary to add the command line argument --dynet-viz.
Immediate Computation¶
In general, DyNet performs symbolic execution. This means that you first create the computation graph, and the computation is actually performed only when you request a value using functions such as forward() or value(). However, if an error occurs during calculation, this can be hard to debug because the error doesn’t occur immediately where the offending graph node is created. To make debugging simpler, you can use immediate computing mode in DyNet. In this mode, every computation gets executed immediately, just like imperative programming, so that you can find exactly where things go wrong.
In C++, you can switch to the immediate computing mode by calling ComputationGraph::set_immediate_compute as follows:
ComputationGraph cg;
cg.set_immediate_compute(true);
Further, DyNet can automatically check the validity of your model, i.e., detect Inf/NaN values, if it is in immediate computing mode. To activate validity checking, add the following code after switching to immediate computing mode:
cg.set_check_validity(true);
In Python, these values can be set by using optional arguments to the renew_cg() function as follows:
import dynet as dy
dy.renew_cg(immediate_compute=True, check_validity=True)
Python Reference Manual¶
Dynet global parameters¶
DynetParams¶
-
class dynet.DynetParams¶ This object holds the global parameters of DyNet.
You should only need to use this after importing DyNet as:
import _dynet / import _gdynet
See the documentation for more details.
-
from_args(shared_parameters=None)¶ Gets parameters from the command line arguments.
You can still modify the parameters after calling this. See the documentation about command line arguments for more details.
Keyword Arguments: shared_parameters ([type]) – [description] (default: None)
-
init()¶ Initialize DyNet with the current DynetParams object.
This is one-way: you can’t uninitialize DyNet.
-
set_mem(mem)¶ Set the memory allocated to DyNet.
The unit is MB.
Parameters: mem (number) – memory size in MB
-
set_random_seed(random_seed)¶ Set the random seed for DyNet.
Parameters: random_seed (number) – Random seed
-
set_requested_gpus(requested_gpus)¶ Number of requested GPUs.
Currently only 1 is supported.
Parameters: requested_gpus (number) – number of requested GPUs
Shared parameters
Parameters: shared_parameters (bool) – shared parameters
-
set_weight_decay(weight_decay)¶ Set the weight decay parameter.
Parameters: weight_decay (float) – weight decay parameter
-
Initialization functions¶
-
dynet.init(shared_parameters=None)¶ Initialize DyNet.
Initializes DyNet from command line arguments. Do not use after import dynet, only after import _dynet / import _gdynet.
Keyword Arguments: shared_parameters (bool) – [description] (default: None)
-
dynet.init_from_params(params)¶ Initialize from DynetParams.
Same as params.init().
Parameters: params (DynetParams) – DyNet parameters
Model and Parameters¶
Model¶
-
class dynet.Model¶ A model holds Parameters. Use it to create, load and save parameters.
-
add_lookup_parameters(dim, init=None)¶ Add a lookup parameter to the model
Parameters: dim (tuple) – Shape of the parameter. The first dimension is the lookup dimension. Keyword Arguments: init (dynet.PyInitializer) – Initializer (default: GlorotInitializer) Returns: Created LookupParameter Return type: (dynet.LookupParameters)
-
add_parameters(dim, init=None)¶ Add a parameter to the model
Parameters: dim (tuple) – Shape of the parameter Keyword Arguments: init (dynet.PyInitializer) – Initializer (default: GlorotInitializer) Returns: Created Parameter Return type: (dynet.Parameters)
-
from_file(fname)¶ Create model from file
Loads all parameters in file and returns model holding them
Parameters: fname (str) – File name Returns: Created model Return type: (dynet.Model)
-
load(fname)¶ Load a list of parameters from file
Parameters: fname (str) – File name Returns: List of parameters loaded from file Return type: (list)
-
load_all(fname)¶ Load all parameters in model from file
Parameters: fname (str) – File name
-
parameters_from_numpy(array)¶ Create parameter from numpy array
Parameters: array (np.ndarray) – Numpy array Returns: Parameter Return type: (dynet.Parameters)
-
save(fname, components=None)¶ Save a list of parameters to file
Parameters: fname (str) – File name Keyword Arguments: components (list) – List of parameters to save (default: None)
-
save_all(fname)¶ Save all parameters in model to file
Parameters: fname (str) – File name
-
Parameters and LookupParameters¶
-
class dynet.Parameters¶ Parameters class
Parameters are things that are optimized. In contrast to a system like Torch, where computational modules may have their own parameters, in DyNet parameters are just parameters.
-
as_array()¶ Return as a numpy array.
Returns: values of the parameter Return type: np.ndarray
-
clip_inplace(left, right)¶ Clip the values in the parameter to a fixed range [left, right] (in place)
Returns: None
-
expr(update=True)¶ Returns the parameter as an expression
This is the same as calling dy.parameter(param).
Parameters: update (bool) – If this is set to False, the parameter won’t be updated during the backward pass Returns: Expression of the parameter Return type: Expression
-
get_index()¶ Get parameter index
Returns: Index of the parameter Return type: unsigned
-
grad_as_array()¶ Return gradient as a numpy array.
Returns: values of the gradient w.r.t. this parameter Return type: np.ndarray
-
is_updated()¶ Check whether the parameter is updated or not
Returns: Update status Return type: bool
-
load_array(arr)¶ Deprecated
-
scale(s)¶ Scales the parameter
Parameters: s (float) – Scale
-
set_updated(b)¶ Set parameter as “updated”
Parameters: b (bool) – updated status
-
shape()¶ [summary]
[description]
Returns: [description] Return type: [type]
-
zero()¶ Set the parameter to zero
-
Parameter initializers¶
-
class dynet.PyInitializer¶ Base class for parameter initializers
-
class dynet.NormalInitializer(mean=0, var=1)¶ Bases: dynet.PyInitializer
Initialize the parameters with a Gaussian distribution
Keyword Arguments: - mean (number) – Mean of the distribution (default: 0)
- var (number) – Variance of the distribution (default: 1)
-
class dynet.UniformInitializer(scale)¶ Bases: dynet.PyInitializer
Initialize the parameters with a uniform distribution
Parameters: scale (number) – Parameters are sampled from \(\mathcal U([-\texttt{scale},\texttt{scale}])\)
-
class dynet.ConstInitializer(c)¶ Bases: dynet.PyInitializer
Initialize the parameters with a constant value
Parameters: c (number) – Value to initialize the parameters
-
class dynet.IdentityInitializer¶ Bases: dynet.PyInitializer
Initialize the parameters as the identity
Only works with square matrices.
-
class dynet.GlorotInitializer(is_lookup=False, gain=1.0)¶ Bases: dynet.PyInitializer
Initializes the weights according to Glorot & Bengio (2011)
If the dimensions of the parameter matrix are \(m,n\), the weights are sampled from \(\mathcal U([-g\sqrt{\frac{6}{m+n}},g\sqrt{\frac{6}{m+n}}])\)
The gain \(g\) depends on the activation function:
- \(\text{tanh}\) : 1.0
- \(\text{ReLU}\) : 0.5
- \(\text{sigmoid}\) : 4.0
- Any smooth function \(f\) : \(\frac{1}{f'(0)}\)
Keyword Arguments: - is_lookup (bool) – Whether the parameter is a lookup parameter (default: False)
- gain (number) – Gain (depends on the activation function) (default: 1.0)
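As a worked example of the Glorot formula, the following plain-Python sketch computes the bound g·sqrt(6/(m+n)) and the gain 1/f'(0) for a smooth activation. It then samples an m x n matrix from the resulting uniform distribution. It only mirrors the formula as stated here. The function names are hypothetical, and this is not how DyNet implements GlorotInitializer internally.

```python
import math
import random


def glorot_bound(m, n, gain=1.0):
    """Half-width of the uniform range for an m x n matrix: g * sqrt(6 / (m + n))."""
    return gain * math.sqrt(6.0 / (m + n))


def gain_from_derivative(f_prime_at_zero):
    """Gain for a smooth activation f, computed as 1 / f'(0)."""
    return 1.0 / f_prime_at_zero


def glorot_uniform(m, n, gain=1.0, rng=random):
    """Sample an m x n matrix (a list of rows) from U([-bound, bound])."""
    bound = glorot_bound(m, n, gain)
    return [[rng.uniform(-bound, bound) for _ in range(n)] for _ in range(m)]


# The sigmoid has derivative 1/4 at 0, which gives the gain of 4.0 listed above.
sigmoid_gain = gain_from_derivative(0.25)
W = glorot_uniform(100, 50, gain=1.0, rng=random.Random(0))
print(glorot_bound(100, 50), sigmoid_gain)
```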
-
class dynet.SaxeInitializer(scale=1.0)¶ Bases: dynet.PyInitializer
Initializes according to Saxe et al. (2014)
Initializes as a random orthonormal matrix (unimplemented for GPU)
Keyword Arguments: scale (number) – Scale to apply to the orthonormal matrix
-
class dynet.FromFileInitializer(fname)¶ Bases: dynet.PyInitializer
Initialize parameter from file
Parameters: fname (str) – File name
-
class dynet.NumpyInitializer(array)¶ Bases: dynet.PyInitializer
Initialize from numpy array
Alternatively, use Model.parameters_from_numpy()
Parameters: array (np.ndarray) – Numpy array
Computation Graph¶
-
dynet.renew_cg(immediate_compute=False, check_validity=False)¶ Renew the computation graph.
Call this before building any new computation graph.
-
dynet.cg_version()¶ Version of the current computation graph
-
dynet.print_text_graphviz()¶
-
dynet.cg_checkpoint()¶ Saves the state of the computation graph
-
dynet.cg_revert()¶ Revert the computation graph state to the previous checkpoint
-
dynet.cg()¶ Get the current ComputationGraph
-
class dynet.ComputationGraph¶ Computation graph object
While the ComputationGraph is central to the inner workings of DyNet, from the user’s perspective, the only responsibility is to create a new computation graph for each training example.
-
parameters(params)¶ Same as dynet.parameters(params)
-
renew(immediate_compute=False, check_validity=False)¶ Same as dynet.renew_cg()
-
version()¶ Same as dynet.cg_version()
-
Operations¶
Expressions¶
-
class dynet.Expression¶ Expressions are the building blocks of a DyNet computation graph.
Expressions are the main data types being manipulated in a DyNet program. Each expression represents a sub-computation in a computation graph.
-
backward(full=False)¶ Run the backward pass based on this expression
The parameter full specifies whether the gradients should be computed for all nodes (True) or only non-constant nodes (False).
By default, a node is constant unless
- it is a parameter node
- it depends on a non-constant node
Thus, functions of constants and inputs are considered as constants.
Turn full on if you want to retrieve gradients w.r.t. inputs, for instance. By default this is turned off, so that the backward pass ignores nodes which have no influence on gradients w.r.t. parameters, for efficiency.
Parameters: full (bool) – Whether to compute all gradients (including with respect to constant nodes).
-
dim()¶ Dimension of the expression
Returns a tuple (dims, batch_dim) where dims is the tuple of dimensions of each batch element
Returns: dimension Return type: tuple
-
forward(recalculate=False)¶ This runs incremental forward on the entire graph
May not be optimal in terms of efficiency. Prefer values.
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False)
-
gradient()¶ Returns the gradient with respect to this expression as a numpy array
The last dimension is the batch size (if it’s > 1).
Make sure to call backward on a downstream expression before calling this.
If the Expression is a constant expression (meaning it’s not a function of a parameter), DyNet won’t compute its gradient for the sake of efficiency. You need to manually force the gradient computation by adding the argument full=True to backward.
Returns: numpy array of gradient values Return type: np.ndarray
-
npvalue(recalculate=False)¶ Returns the value of the expression as a numpy array
The last dimension is the batch size (if it’s > 1)
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: numpy array of values Return type: np.ndarray
-
scalar_value(recalculate=False)¶ Returns the value of an expression as a scalar
This only works if the expression is a scalar
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Scalar value of the expression Return type: float
-
tensor_value(recalculate=False)¶ Returns the value of the expression as a Tensor.
This is useful if you want to use the value for other on-device calculations that are not part of the computation graph, e.g. using argmax.
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: a DyNet Tensor object. Return type: Tensor
-
value(recalculate=False)¶ Gets the value of the expression in the most relevant format
This returns the same thing as scalar_value, vec_value or npvalue, depending on whether the number of dimensions of the expression is 0, 1 or 2+.
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Value of the expression Return type: float, list, np.ndarray
-
vec_value(recalculate=False)¶ Returns the value of the expression as a vector
In case of a multidimensional expression, the values are flattened according to column-major ordering
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Array of values Return type: list
-
Operations¶
Operations are used to build expressions
Input operations¶
-
dynet.parameter(p, update=True)¶ Load a parameter in the computation graph
Get the expression corresponding to a parameter
Parameters: - p (Parameter,LookupParameter) – Parameter to load (can be a lookup parameter as well)
- update (bool) – If this is set to False, the parameter won’t be updated during the backward pass
Returns: Parameter expression
Return type: dynet.Expression
Raises: NotImplementedError – Only works with parameters and lookup parameters
-
dynet.inputTensor(arr, batched=False)¶ Creates a tensor expression based on a numpy array or a list.
The dimension is inferred from the shape of the input. If batched=True, the last dimension is used as a batch dimension. If arr is a list of numpy ndarrays, this returns a batched expression where the batch elements are the elements of the list.
Parameters: arr (list,np.ndarray) – Values : numpy ndarray OR list of np.ndarray OR multidimensional list of floats Keyword Arguments: batched (bool) – Whether to use the last dimension as a batch dimension (default: False) Returns: Input expression Return type: _vecInputExpression Raises: TypeError – If the type is not respected
-
dynet.scalarInput(s)¶
-
dynet.vecInput(dim)¶ Input an empty vector
Parameters: dim (number) – Size Returns: Corresponding expression Return type: _vecInputExpression
-
dynet.inputVector(v)¶ Input a vector by values
Parameters: v (vector[float]) – Values Returns: Corresponding expression Return type: _vecInputExpression
-
dynet.matInput(d1, d2)¶ DEPRECATED: use inputTensor
TODO: remove this
Parameters: - d1 (int) – [description]
- d2 (int) – [description]
Returns: [description]
Return type:
-
dynet.inputMatrix(v, d)¶ DEPRECATED: use inputTensor
TODO: remove this
inputMatrix(vector[float] v, tuple d)
Create a matrix literal. The first argument is a list of floats (or a flat numpy array). The second argument is a dimension. Returns: an expression. Usage example:
x = inputMatrix([1,2,3,4,5,6],(2,3))
x.npvalue()
--> array([[ 1., 3., 5.],
           [ 2., 4., 6.]])
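The inputMatrix output follows from column-major ordering: consecutive values fill the first column, then the second, and so on. vec_value flattens in the same order. This plain-Python sketch uses hypothetical helper names and makes no DyNet calls. It reproduces that layout so the ordering can be checked without a DyNet installation.

```python
def fill_column_major(values, rows, cols):
    """Arrange a flat list into a rows x cols matrix, filling column by column."""
    if len(values) != rows * cols:
        raise ValueError("expected %d values, got %d" % (rows * cols, len(values)))
    # Element (r, c) sits at flat position c * rows + r in column-major order.
    return [[values[c * rows + r] for c in range(cols)] for r in range(rows)]


def flatten_column_major(matrix):
    """Inverse of fill_column_major: read the matrix column by column."""
    rows, cols = len(matrix), len(matrix[0])
    return [matrix[r][c] for c in range(cols) for r in range(rows)]


m = fill_column_major([1, 2, 3, 4, 5, 6], 2, 3)
print(m)                        # [[1, 3, 5], [2, 4, 6]]
print(flatten_column_major(m))  # [1, 2, 3, 4, 5, 6]
```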
-
dynet.lookup(p, index=0, update=True)¶ Pick an embedding from a lookup parameter and return it as an expression
Parameters: p (LookupParameters) – Lookup parameter to pick from
Keyword Arguments: - index (number) – Lookup index (default: 0)
- update (bool) – Whether to update the lookup parameter (default: True)
Returns: Expression for the embedding
Return type: _lookupExpression
-
dynet.lookup_batch(p, indices, update=True)¶ Look up parameters.
The mini-batched version of lookup. The resulting expression will be a mini-batch of parameters, where the “i”th element of the batch corresponds to the parameters at the position specified by the “i”th element of “indices”.
Parameters: - p (LookupParameters) – Lookup parameter to pick from
- indices (list(int)) – Indices to look up for each batch element
Keyword Arguments: update (bool) – Whether to update the lookup parameter (default: True)
Returns: Expression for the batched embeddings
Return type: _lookupBatchExpression
-
dynet.zeroes(dim, batch_size=1)¶ Create an input full of zeros
Create an input full of zeros, sized according to dimensions dim
Parameters: dim (tuple) – Dimension of the tensor Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1)) Returns: A “d” dimensioned zero tensor Return type: dynet.Expression
-
dynet.random_normal(dim, batch_size=1)¶ Create a random normal vector
Create a vector distributed according to a normal distribution with mean 0 and variance 1.
Parameters: dim (tuple) – Dimension of the tensor Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1)) Returns: A “d” dimensioned normally distributed tensor Return type: dynet.Expression
-
dynet.random_bernoulli(dim, p, scale=1.0, batch_size=1)¶ Create a random Bernoulli tensor
Create a tensor distributed according to a Bernoulli distribution with parameter \(p\).
Parameters: - dim (tuple) – Dimension of the tensor
- p (number) – Parameter of the Bernoulli distribution
Keyword Arguments: - scale (number) – Scaling factor to apply to the sampled tensor (default: (1.0))
- batch_size (number) – Batch size of the tensor (default: (1))
Returns: A “d” dimensioned Bernoulli distributed tensor
Return type: dynet.Expression
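The following plain-Python sketch shows what the sampled tensor from random_bernoulli contains. It assumes each entry is independently 1 with probability p and is then multiplied by scale, so every entry is either 0 or scale. The helper name bernoulli_mask is hypothetical and handles only a 2-D shape. In a dropout-style mask one would typically pick scale = 1/p, so that the expected value of each entry is 1.

```python
import random


def bernoulli_mask(shape, p, scale=1.0, rng=random):
    """Return a rows x cols nested list with entries in {0, scale}.

    Each entry is 1 with probability p (0 otherwise) and is then multiplied by scale.
    """
    rows, cols = shape
    return [[scale * (1.0 if rng.random() < p else 0.0) for _ in range(cols)]
            for _ in range(rows)]


keep_prob = 0.8
mask = bernoulli_mask((4, 5), p=keep_prob, scale=1.0 / keep_prob, rng=random.Random(42))
print(mask)
```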
-
dynet.random_uniform(dim, left, right, batch_size=1)¶ Create a random uniform tensor
Create a tensor distributed according to a uniform distribution with boundaries left and right.
Parameters: - dim (tuple) – Dimension of the tensor
- left (number) – Lower bound of the uniform distribution
- right (number) – Upper bound of the uniform distribution
Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1))
Returns: A “d” dimensioned uniform distributed tensor
Return type: dynet.Expression
-
dynet.noise(x, stddev)¶ Additive Gaussian noise
Add Gaussian noise to an expression.
Parameters: - x (dynet.Expression) – Input expression
- stddev (number) – The standard deviation of the Gaussian
Returns: \(y\sim\mathcal N(x,\texttt{stddev}^2)\)
Return type: dynet.Expression
Arithmetic operations¶
-
dynet.cdiv(x, y)¶ Componentwise division
Do a componentwise division where each value is equal to \(\frac{x_i}{y_i}\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: An expression where the ith element is equal to \(\frac{x_i}{y_i}\)
Return type: dynet.Expression
-
dynet.cmult(x, y)¶ Componentwise multiplication
Do a componentwise multiplication where each value is equal to \(x_i\times y_i\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: An expression where the ith element is equal to \(x_i\times y_i\)
Return type: dynet.Expression
-
dynet.colwise_add(x, y)¶ Columnwise addition
Add vector \(y\) to each column of matrix \(x\)
Parameters: - x (dynet.Expression) – An MxN matrix
- y (dynet.Expression) – A length M vector
Returns: An expression where \(y\) is added to each column of \(x\)
Return type: dynet.Expression
-
dynet.squared_norm(x)¶ Squared norm
The squared norm of the values of x: \(\Vert x\Vert_2^2=\sum_i x_i^2\).
Parameters: x (dynet.Expression) – Input expression Returns: \(\Vert x\Vert_2^2=\sum_i x_i^2\) Return type: dynet.Expression
-
dynet.tanh(x)¶ Hyperbolic tangent
Elementwise calculation of the hyperbolic tangent
Parameters: x (dynet.Expression) – Input expression Returns: \(\tanh(x)\) Return type: dynet.Expression
-
dynet.exp(x)¶ Natural exponent
Calculate elementwise \(y_i = e^{x_i}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(e^{x}\) Return type: dynet.Expression
-
dynet.square(x)¶ Square
Calculate elementwise \(y_i = x_i^2\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = x^2\) Return type: dynet.Expression
-
dynet.sqrt(x)¶ Square root
Calculate elementwise \(y_i = \sqrt{x_i}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = \sqrt{x}\) Return type: dynet.Expression
-
dynet.abs(x)¶ Absolute value
Calculate elementwise \(y_i = \vert x_i\vert\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = \vert x\vert\) Return type: dynet.Expression
-
dynet.erf(x)¶ Gaussian error function
Elementwise calculation of the Gaussian error function \(y_i = \text{erf}(x_i)=\frac {1}{\sqrt{\pi}}\int_{-x_i}^{x_i}e^{-t^2}\mathrm{d}t\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \text{erf}(x_i)\) Return type: dynet.Expression
-
dynet.cube(x)¶ Cube
Calculate elementwise \(y_i = x_i^3\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = x^3\) Return type: dynet.Expression
-
dynet.log(x)¶ Natural logarithm
Elementwise calculation of the natural logarithm \(y_i = \ln(x_i)\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \ln(x_i)\) Return type: dynet.Expression
-
dynet.lgamma(x)¶ Log gamma
Calculate elementwise log gamma function \(y_i = \ln(\Gamma(x_i))\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \ln(\Gamma(x_i))\) Return type: dynet.Expression
-
dynet.logistic(x)¶ Logistic sigmoid function
Calculate elementwise \(y_i = \frac{1}{1+e^{-x_i}}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \frac{1}{1+e^{-x_i}}\) Return type: dynet.Expression
-
dynet.rectify(x)¶ Rectifier (or ReLU, Rectified Linear Unit)
Calculate the elementwise rectifier (ReLU) function \(y_i = \max(x_i,0)\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \max(x_i,0)\) Return type: dynet.Expression
-
dynet.sparsemax(x)¶ Sparsemax
The sparsemax function (Martins et al. 2016), which is similar to softmax, but induces sparse solutions where most of the vector elements are zero. Note: This function is not yet implemented on GPU.
Parameters: x (dynet.Expression) – Input expression Returns: The sparsemax of the scores Return type: dynet.Expression
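For reference, sparsemax is the Euclidean projection of a score vector onto the probability simplex. The following plain-Python sketch uses the sort-based algorithm from the sparsemax paper cited above. It is meant to show what the function computes, not how DyNet implements it, and the function name is our own.

```python
def sparsemax(z):
    """Project the score list z onto the probability simplex (sparsemax)."""
    z_sorted = sorted(z, reverse=True)
    cumulative = 0.0
    k_z = 0
    cumsum_at_k = 0.0
    for k, z_k in enumerate(z_sorted, start=1):
        cumulative += z_k
        # Support condition: 1 + k * z_(k) > sum of the k largest scores.
        if 1.0 + k * z_k > cumulative:
            k_z = k
            cumsum_at_k = cumulative
    # Threshold tau chosen so that the surviving (positive) entries sum to one.
    tau = (cumsum_at_k - 1.0) / k_z
    return [max(z_i - tau, 0.0) for z_i in z]


print(sparsemax([1.0, 0.0]))  # [1.0, 0.0]
print(sparsemax([0.5, 0.2, 0.1]))
```

Unlike softmax, the first call puts all of the mass on one element and returns an exact zero for the other.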
-
dynet.softsign(x)¶ Softsign function
Calculate elementwise the softsign function \(y_i = \frac{x_i}{1+\vert x_i\vert}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \frac{x_i}{1+\vert x_i\vert}\) Return type: dynet.Expression
-
dynet.pow(x, y)¶ Power function
Calculate an output where the ith element is equal to \(x_i^{y_i}\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(x_i^{y_i}\)
Return type: dynet.Expression
-
dynet.bmin(x, y)¶ Minimum
Calculate an output where the ith element is \(\min(x_i,y_i)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\min(x_i,y_i)\)
Return type: dynet.Expression
-
dynet.bmax(x, y)¶ Maximum
Calculate an output where the ith element is \(\max(x_i,y_i)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\max(x_i,y_i)\)
Return type: dynet.Expression
-
dynet.transpose(x, dims=[1, 0])¶ Transpose a matrix
Get the transpose of the matrix, or, if dims is specified, shuffle the dimensions arbitrarily.
Note: This is O(1) if either the row or column dimension is 1, and O(n) otherwise.
Parameters: - x (dynet.Expression) – Input expression
- dims (list) – The dimensions to swap. The ith dimension of the output will be equal to the dims[i] dimension of the input. dims must have the same number of dimensions as x.
Returns: \(x^T\) / the shuffled expression
Return type: dynet.Expression
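The transpose rule for dims can be checked on shapes alone: output dimension i has the size of input dimension dims[i]. This small plain-Python sketch works on shapes only and makes no DyNet calls. The helper transposed_shape is hypothetical.

```python
def transposed_shape(shape, dims=(1, 0)):
    """Shape after transpose: output dimension i takes the size of input dimension dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of 0..%d" % (len(shape) - 1))
    return tuple(shape[d] for d in dims)


print(transposed_shape((2, 3)))                     # (3, 2)
print(transposed_shape((2, 3, 4), dims=[2, 0, 1]))  # (4, 2, 3)
```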
-
dynet.sum_cols(x)¶ [summary]
[description]
Parameters: x (dynet.Expression) – Returns: Return type: dynet.Expression
-
dynet.sum_elems(x)¶ Sum all elements
Sum all the elements in an expression.
Parameters: x (dynet.Expression) – Input expression Returns: The sum of all of its elements Return type: dynet.Expression
-
dynet.sum_batches(x)¶ Sum over minibatches
Sum an expression that consists of multiple minibatches into one of equal dimension but with only a single minibatch. This is useful for summing loss functions at the end of minibatch training.
Parameters: x (dynet.Expression) – Input expression Returns: An expression with a single batch Return type: dynet.Expression
-
dynet.fold_rows(x, nrows=2)¶ [summary]
[description]
Parameters: x (dynet.Expression) – Keyword Arguments: nrows {number} (unsigned) – (default: (2)) Returns: Return type: dynet.Expression
-
dynet.esum(xs)¶ Sum
This performs an elementwise sum over all the expressions in xs
Parameters: xs (list) – A list of expressions of the same dimension Returns: An expression where the ith element is equal to \(\sum_{j=0}\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
-
dynet.logsumexp(xs)¶ Log, sum, exp
The elementwise “logsumexp” function that calculates \(\ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.
Parameters: xs (list) – A list of expressions of the same dimension Returns: An expression where the ith element is equal to \(\ln\left(\sum_{j=0}e^{\texttt{xs[}j\texttt{][}i\texttt{]}}\right)\) Return type: dynet.Expression
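Computed naively, \(\ln(\sum_i e^{x_i})\) overflows for large scores. A common stable formulation subtracts the maximum first. The plain-Python sketch below applies it to the elementwise operation over a list of equal-length vectors. The names are our own, and the code illustrates the math rather than DyNet's kernel.

```python
import math


def logsumexp_scalar(values):
    """Stable ln(sum(exp(v))) for a list of floats, using the max-shift trick."""
    m = max(values)
    if m == float("-inf"):
        return m  # every term is exp(-inf) = 0, so the log of the sum is -inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def elementwise_logsumexp(xs):
    """Mirror logsumexp(xs): combine the ith elements of every vector in xs."""
    return [logsumexp_scalar(column) for column in zip(*xs)]


xs = [[0.0, 1000.0], [0.0, 1000.0]]
print(elementwise_logsumexp(xs))  # each entry equals its input value plus ln 2
```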
-
dynet.average(xs)¶ Average
This performs an elementwise average over all the expressions in xs
Parameters: xs (list) – A list of expressions of the same dimension Returns: An expression where the ith element is equal to \(\frac{1}{\texttt{len(xs)}}\sum_{j=0}\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
-
dynet.emax(xs)¶ Max
This performs an elementwise max over all the expressions in xs
Parameters: xs (list) – A list of expressions of the same dimension Returns: An expression where the ith element is equal to \(\max_j\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
Loss/Probability operations¶
-
dynet.softmax(x)¶ Softmax
The softmax function normalizes each column to ensure that all values are between 0 and 1 and add to one by applying \(\frac{e^{x_i}}{\sum_j e^{x_j}}\).
Parameters: x (dynet.Expression) – Input expression Returns: \(\frac{e^{x_i}}{\sum_j e^{x_j}}\) Return type: dynet.Expression
-
dynet.log_softmax(x, restrict=None)¶ Restricted log softmax
The log softmax function calculated over only a subset of the vector elements. The elements to be included are set by the restrict variable. All elements not included in restrict are set to negative infinity.
Parameters: x (dynet.Expression) – Input expression Keyword Arguments: restrict (list) – List of indices of the elements over which to compute the log softmax (default: (None)) Returns: A vector with the log softmax over the specified elements Return type: dynet.Expression
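The following plain-Python sketch implements the restricted log softmax, assuming restrict is a list of indices. The log softmax is normalized over the listed elements only, and every other position is set to negative infinity. The function name restricted_log_softmax is hypothetical.

```python
import math


def restricted_log_softmax(x, restrict=None):
    """Log softmax over the indices in restrict; all other entries become -inf."""
    indices = range(len(x)) if restrict is None else restrict
    m = max(x[i] for i in indices)
    log_z = m + math.log(sum(math.exp(x[i] - m) for i in indices))
    allowed = set(indices)
    return [x[i] - log_z if i in allowed else float("-inf") for i in range(len(x))]


out = restricted_log_softmax([1.0, 2.0, 3.0], restrict=[0, 2])
print(out)  # index 1 is -inf; the exponentials of the other two entries sum to 1
```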
-
dynet.pairwise_rank_loss(x, y, m=1.0)¶ Pairwise rank loss
A margin-based loss, where every margin violation for each pair of values is penalized: \(\sum_i \max(x_i-y_i+m, 0)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Keyword Arguments: m (number) – The margin (default: (1.0))
Returns: The pairwise rank loss
Return type: dynet.Expression
-
dynet.poisson_loss(x, y)¶ Poisson loss
The negative log probability of y according to a Poisson distribution with parameter x. Useful in Poisson regression, where we try to predict the parameters of a Poisson distribution to maximize the probability of data y.
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: The Poisson loss
Return type: dynet.Expression
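Taking the poisson_loss definition literally, the negative log probability of a count y under a Poisson distribution with rate x is x - y ln x + ln(y!). The following plain-Python sketch evaluates that formula, using math.lgamma(y + 1) for ln(y!). It illustrates the definition only and is not taken from DyNet's source. The helper poisson_nll is hypothetical.

```python
import math


def poisson_nll(rate, count):
    """-log P(count | rate) for a Poisson distribution: rate - count*ln(rate) + ln(count!)."""
    if rate <= 0:
        raise ValueError("the Poisson rate must be positive")
    return rate - count * math.log(rate) + math.lgamma(count + 1)


# Sanity check against the probability mass function computed directly.
rate, count = 3.0, 2
direct = -math.log(math.exp(-rate) * rate ** count / math.factorial(count))
print(poisson_nll(rate, count), direct)
```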
-
dynet.huber_distance(x, y, c=1.345)¶ Huber distance
The Huber distance between values of x and y parameterized by c, \(\sum_i L_c(x_i, y_i)\) where:
\[L_c(x, y) = \begin{cases} \frac{1}{2}(y - x)^2 & \textrm{for } \vert y - x\vert \le c, \\ c\, \vert y - x\vert - \frac{1}{2}c^2 & \textrm{otherwise.} \end{cases}\]
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Keyword Arguments: c (number) – The parameter of the Huber distance parameterizing the cutoff (default: (1.345))
Returns: The Huber distance
Return type: dynet.Expression
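The following plain-Python sketch computes the Huber distance defined for huber_distance. The term is quadratic for small differences and linear beyond the cutoff c, and the two pieces meet at |y - x| = c. The helper names are our own. The default c = 1.345 matches the keyword default listed for huber_distance.

```python
def huber_term(x, y, c=1.345):
    """L_c(x, y): 0.5*(y-x)^2 if |y-x| <= c, else c*|y-x| - 0.5*c^2."""
    d = abs(y - x)
    if d <= c:
        return 0.5 * d * d
    return c * d - 0.5 * c * c


def huber_distance(xs, ys, c=1.345):
    """Sum of L_c over paired elements of two equal-length lists."""
    return sum(huber_term(x, y, c) for x, y in zip(xs, ys))


print(huber_distance([0.0, 0.0], [1.0, 3.0], c=1.0))  # 0.5 + 2.5 = 3.0
```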
-
dynet.pickneglogsoftmax(x, v)¶ Negative softmax log likelihood
This function takes in a vector of scores x, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the element v. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.
Parameters: - x (dynet.Expression) – Input scores
- v (int) – True class
Returns: \(-\log\left(\frac{e^{x_v}}{\sum_j e^{x_j}}\right)\)
Return type: dynet.Expression
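The value returned by pickneglogsoftmax can be checked by hand. This plain-Python sketch uses our own helper and applies the max-shift for numerical stability. It computes \(-\log(e^{x_v} / \sum_j e^{x_j})\) for one score vector and a true class index v.

```python
import math


def pick_neg_log_softmax(scores, v):
    """Negative log softmax probability of class v."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[v]


print(pick_neg_log_softmax([0.0, 0.0, 0.0, 0.0], 2))  # ln 4, about 1.386
```

With four equal scores each class has probability 1/4, which is why the loss is ln 4.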
-
dynet.pickneglogsoftmax_batch(x, vs)¶ Negative softmax log likelihood on a batch
This function takes in a batched vector of scores x, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the elements vs. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.
Parameters: - x (dynet.Expression) – Input scores
- vs (list) – True classes
Returns: A batched expression whose element for batch entry \(b\) is \(-\log\left(\frac{e^{x_{b,\texttt{vs}[b]}}}{\sum_j e^{x_{b,j}}}\right)\); use sum_batches to add them up
Return type: dynet.Expression
-
dynet.kmh_ngram(x, v)¶ [summary]
[description]
Parameters: - x (dynet.Expression) –
- v (dynet.Expression) –
Returns: Return type:
-
dynet.squared_distance(x, y)¶ Squared distance
The squared distance between values of x and y: \(\Vert x-y\Vert_2^2=\sum_i (x_i-y_i)^2\).
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\Vert x-y\Vert_2^2=\sum_i (x_i-y_i)^2\)
Return type: dynet.Expression
-
dynet.l1_distance(x, y)¶ L1 distance
L1 distance between values of x and y: \(\Vert x-y\Vert_1=\sum_i \vert x_i-y_i\vert\).
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\Vert x-y\Vert_1=\sum_i \vert x_i-y_i\vert\).
Return type: dynet.Expression
-
dynet.binary_log_loss(x, y)¶ Binary log loss
The log loss of a binary decision according to the sigmoid function: \(- \sum_i (y_i \ln(x_i) + (1-y_i) \ln(1-x_i))\). The values of x are therefore expected to be probabilities in (0, 1), such as the output of a sigmoid.
Parameters: - x (dynet.Expression) – The first input expression (predicted probabilities)
- y (dynet.Expression) – The second input expression (target values)
Returns: \(- \sum_i (y_i \ln(x_i) + (1-y_i) \ln(1-x_i))\)
Return type: dynet.Expression
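A plain-Python sketch of the same formula follows. It is illustrative only: it works on lists of floats, where x holds predicted probabilities and y the 0/1 targets, rather than on DyNet expressions.

```python
import math


def binary_log_loss(x, y):
    """-sum_i (y_i ln x_i + (1 - y_i) ln(1 - x_i)), with x = predicted probabilities."""
    return -sum(yi * math.log(xi) + (1 - yi) * math.log(1 - xi)
                for xi, yi in zip(x, y))


# Predicting 0.5 for one positive and one negative target costs ln 2 each.
print(binary_log_loss([0.5, 0.5], [1.0, 0.0]))  # 2 * ln(2)
```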
Flow/Shaping operations¶
-
dynet.pick(e, index=0, dim=0)¶ Pick element.
Pick a single element/row/column/sub-tensor from an expression. This will result in the dimension of the tensor being reduced by 1.
Parameters: e (Expression) – Expression to pick from
Keyword Arguments: - index (number) – Index to pick (default: 0)
- dim (number) – Dimension to pick from (default: 0)
Returns: Picked expression
Return type: _pickerExpression
-
dynet.pick_batch(e, indices, dim=0)¶ Batched pick.
Pick one element from each batch element: the b-th entry of indices is picked from the b-th batch element.
Parameters: - e (Expression) – Expression to pick from
- indices (list) – Indices to pick
- dim (number) – Dimension to pick from (default: 0)
Returns: Picked expression
Return type: _pickerBatchExpression
-
dynet.pickrange(x, v, u)¶ Pick range of elements
Pick a range of elements from an expression.
Parameters: - x (dynet.Expression) – Input expression
- v (int) – Beginning index
- u (int) – End index
Returns: The value of {x[v],...,x[u]}
Return type: dynet.Expression
-
dynet.pick_batch_elem(x, v)¶ Pick batch element.
Pick a batch element from a batched expression. Batch indices are zero-based. For a Tensor t with 3 batch elements:
\[\begin{split}\begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]
pick_batch_elem(t, 1) will return a Tensor of
\[\begin{split}\begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\end{split}\]
Parameters: - x (dynet.Expression) – Input expression
- v (int) – The index of the batch element to be picked.
Returns: The expression of the picked batch element. The picked element is a tensor whose batch dimension equals one.
Return type: dynet.Expression
-
dynet.pick_batch_elems(x, vs)¶ Pick batch elements.
Pick several batch elements from a batched expression. Batch indices are zero-based. For a Tensor t with 3 batch elements:
\[\begin{split}\begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]
pick_batch_elems(t, [1, 2]) will return a Tensor of
\[\begin{split}\begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]
Parameters: - x (dynet.Expression) – Input expression
- vs (list) – A list of indices of the batch elements to be picked.
Returns: The expression of the picked batch elements. The result is a tensor whose batch dimension equals the length of vs.
Return type: dynet.Expression
-
dynet.reshape(x, d, batch_size=1)¶ Reshape to another size
This node reshapes a tensor to another size, without changing the underlying layout of the data. The layout of the data in DyNet is column-major, so if we have a 3x4 matrix:
\[\begin{split}\begin{pmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \\ \end{pmatrix}\end{split}\]
and transform it into a 2x6 matrix, it will be rearranged as:
\[\begin{split}\begin{pmatrix} x_{1,1} & x_{3,1} & x_{2,2} & x_{1,3} & x_{3,3} & x_{2,4} \\ x_{2,1} & x_{1,2} & x_{3,2} & x_{2,3} & x_{1,4} & x_{3,4} \\ \end{pmatrix}\end{split}\]
Note: This is O(1) for forward, and O(n) for backward.
Parameters: - x (dynet.Expression) – Input expression
- d (tuple) – New dimension
Keyword Arguments: batch_size (int) – New batch size (default: 1)
Returns: The reshaped expression
Return type: dynet.Expression
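The 3x4 to 2x6 rearrangement above follows from reading and writing the same buffer in column-major order. The plain-Python sketch below reproduces it on a list-of-rows matrix. The helper name reshape_col_major is hypothetical; this is not how DyNet stores tensors internally, only an illustration of the ordering.

```python
def reshape_col_major(m, rows, cols):
    """Reshape a list-of-rows matrix, reading and writing in column-major order."""
    n_rows, n_cols = len(m), len(m[0])
    assert n_rows * n_cols == rows * cols, "sizes must match"
    # Flatten column by column: this mimics the column-major memory layout.
    flat = [m[i][j] for j in range(n_cols) for i in range(n_rows)]
    # Refill the new shape column by column from the same flat buffer.
    return [[flat[j * rows + i] for j in range(cols)] for i in range(rows)]


# Label each entry of the 3x4 matrix from the text by its (row, column) index.
m = [[f"x{i}{j}" for j in range(1, 5)] for i in range(1, 4)]
for row in reshape_col_major(m, 2, 6):
    print(row)  # prints the two rows of the 2x6 matrix shown above
```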
-
dynet.select_rows(x, rs)¶ Select rows
Select a subset of rows of a matrix.
Parameters: - x (dynet.Expression) – Input expression
- rs (list) – The rows to extract
Returns: An expression containing the selected rows
Return type: dynet.Expression
-
dynet.select_cols(x, cs)¶ Select columns
Select a subset of columns of a matrix.
Parameters: - x (dynet.Expression) – Input expression
- cs (list) – The columns to extract
Returns: An expression containing the selected columns
Return type: dynet.Expression
-
dynet.concatenate_cols(xs)¶ Concatenate columns
Perform a concatenation of the columns in multiple expressions. All expressions must have the same number of rows.
Parameters: xs (list) – A list of expressions
Returns: The expression with the columns concatenated
Return type: dynet.Expression
-
dynet.concatenate(xs, d=0)¶ Concatenate
Perform a concatenation of multiple expressions along a particular dimension. All expressions must have the same dimensions except for the dimension to be concatenated (rows by default).
Parameters: - xs (list) – A list of expressions
- d (int) – The dimension along which to perform concatenation (default: 0)
Returns: The expression concatenated along the particular dimension
Return type: dynet.Expression
-
dynet.concatenate_to_batch(xs)¶ Concatenate list of expressions to a single batched expression
Perform a concatenation of several expressions along the batch dimension. All expressions must have the same shape except for the batch dimension.
Parameters: xs (list) – A list of expressions of same dimension (except batch size)
Returns: The expression with the batch dimensions concatenated
Return type: dynet.Expression
-
dynet.max_dim(x, d=0)¶ Max out through a dimension
Select out an element/row/column/sub-tensor from an expression, with maximum value along a given dimension. This will result in the dimension of the expression being reduced by 1.
Parameters: x (dynet.Expression) – Input expression
Keyword Arguments: d (int) – Dimension on which to perform the maxout (default: 0)
Returns: An expression of sub-tensor with max value along dimension d
Return type: dynet.Expression
-
dynet.min_dim(x, d=0)¶ Min out through a dimension
Select out an element/row/column/sub-tensor from an expression, with minimum value along a given dimension. This will result in the dimension of the expression being reduced by 1.
Parameters: x (dynet.Expression) – Input expression
Keyword Arguments: d (int) – Dimension on which to perform the minout (default: 0)
Returns: An expression of sub-tensor with min value along dimension d
Return type: dynet.Expression
-
dynet.nobackprop(x)¶ Prevent backprop
This node has no effect on the forward pass, but prevents gradients from flowing backward during the backward pass. This is useful when there’s a subgraph for which you don’t want loss passed back to the parameters.
Parameters: x (dynet.Expression) – Input expression
Returns: An output expression with the same value as the input (only the backward pass is affected)
Return type: dynet.Expression
-
dynet.flip_gradient(x)¶ Negative backprop
This node has no effect on the forward pass, but negates the gradient during the backward pass. This operation is widely used in adversarial networks.
Parameters: x (dynet.Expression) – Input expression
Returns: An output expression with the same value as the input (only the backward pass is affected)
Return type: dynet.Expression
Noise operations¶
-
dynet.noise(x, stddev)¶ Additive Gaussian noise
Add Gaussian noise to an expression.
Parameters: - x (dynet.Expression) – Input expression
- stddev (number) – The standard deviation of the Gaussian
Returns: \(y\sim\mathcal N(x,\texttt{stddev})\)
Return type: dynet.Expression
-
dynet.dropout(x, p)¶ Dropout
With a fixed probability p, drop out (set to zero) nodes in the input expression, and scale the remaining nodes by 1/(1-p). Note that there are two kinds of dropout:
- Regular dropout: where we perform dropout at training time and then scale outputs by 1-p at test time.
- Inverted dropout: where we perform dropout and scaling at training time, and do not need to do anything at test time.
DyNet implements the latter, so you only need to apply dropout at training time, and do not need to perform scaling at test time.
Parameters: - x (dynet.Expression) – Input expression
- p (number) – The dropout probability
Returns: The dropped out expression \(y=\frac{1}{1-\texttt{p}}x\circ z, z\sim\text{Bernoulli}(1-\texttt{p})\)
Return type: dynet.Expression
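Inverted dropout, as described above, can be sketched in plain Python. This is only an illustration of the returned formula, not DyNet's code: the helper name inverted_dropout and the explicit rng argument are assumptions made so that the example is reproducible.

```python
import random


def inverted_dropout(x, p, rng=random):
    """Zero each entry with probability p and scale survivors by 1/(1-p)."""
    scale = 1.0 / (1.0 - p)
    # z_i ~ Bernoulli(1 - p): keep the entry when the uniform draw is >= p.
    return [xi * scale if rng.random() >= p else 0.0 for xi in x]


rng = random.Random(0)
y = inverted_dropout([1.0] * 10000, 0.3, rng)
# Kept entries equal 1/(1-0.3), so the mean stays close to 1.0 and
# no rescaling is needed at test time.
print(sum(y) / len(y))
```

Because the expectation of each output already matches its input, the test-time network can use the inputs unchanged. This is the reason DyNet only needs dropout at training time.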
-
dynet.block_dropout(x, p)¶ Block dropout
Identical to the dropout operation, but either drops out all or no values in the expression, as opposed to making a decision about each value individually.
Parameters: - x (dynet.Expression) – Input expression
- p (number) – The dropout probability
Returns: The block dropout expression
Return type: dynet.Expression
Linear algebra operations¶
-
dynet.affine_transform(exprs)¶ Affine transform
This performs an affine transform over an arbitrary (odd) number of expressions held in the input list exprs. The first expression is the “bias,” which is added to the expression as-is. The remaining expressions are multiplied together in pairs, then added. A very common usage case is the calculation of the score for a neural network layer (e.g. \(b + Wz\)) where b is the bias, W is the weight matrix, and z is the input. In this case exprs[0] = b, exprs[1] = W, and exprs[2] = z.
Parameters: exprs (list) – A list containing an odd number of expressions
Returns: An expression equal to: exprs[0] + exprs[1]*exprs[2] + exprs[3]*exprs[4] + ...
Return type: dynet.Expression
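The pairing rule (bias first, then weight/input pairs) is sketched below with plain Python lists. The helpers matvec and affine_transform here are illustrative stand-ins, not the DyNet functions; matrices are lists of rows and vectors are lists of floats.

```python
def matvec(W, z):
    """Multiply a list-of-rows matrix W by a vector z."""
    return [sum(wij * zj for wij, zj in zip(row, z)) for row in W]


def affine_transform(exprs):
    """exprs[0] + exprs[1]*exprs[2] + exprs[3]*exprs[4] + ... on plain lists."""
    assert len(exprs) % 2 == 1, "need an odd number of operands"
    out = list(exprs[0])  # start from the bias
    for W, z in zip(exprs[1::2], exprs[2::2]):
        out = [o + q for o, q in zip(out, matvec(W, z))]
    return out


b = [1.0, 0.0]
W = [[1.0, 2.0], [3.0, 4.0]]
z = [1.0, 1.0]
print(affine_transform([b, W, z]))  # [1 + 3, 0 + 7] = [4.0, 7.0]
```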
-
dynet.dot_product(x, y)¶ Dot Product
Calculate the dot product \(x^Ty=\sum_i x_iy_i\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(x^Ty=\sum_i x_iy_i\)
Return type: dynet.Expression
-
dynet.inverse(x)¶ Matrix Inverse
Takes the inverse of a matrix (not implemented on GPU yet, although contributions are welcome: issue). Note that back-propagating through an inverted matrix can sometimes be a source of numerical stability problems.
Parameters: x (dynet.Expression) – Input expression
Returns: Inverse of x
Return type: dynet.Expression
-
dynet.trace_of_product(x, y)¶ Trace of Matrix Product
Takes the trace of the product of matrices (not implemented on GPU yet, although contributions are welcome: issue).
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\text{Tr}(xy)\)
Return type: dynet.Expression
-
dynet.logdet(x)¶ Log determinant
Takes the log of the determinant of a matrix (not implemented on GPU yet, although contributions are welcome: issue).
Parameters: x (dynet.Expression) – Input expression
Returns: \(\log(\vert x\vert)\)
Return type: dynet.Expression
Convolution/Pooling operations¶
-
dynet.conv2d(x, f, stride, is_valid=True)¶ 2D convolution without bias
2D convolution operator without bias parameters. VALID and SAME convolutions are supported. When the stride is 1, the distinction is:
SAME: the output size is the same as the input size. To achieve this, the input is padded so that the filter can sweep outside of the input maps.
VALID: the output size shrinks by filter_size - 1, and the filters always sweep at valid positions inside the input maps. No padding is needed.
In detail, assume
- Input feature maps: XH x XW x XC x N
- Filters: FH x FW x XC x FC
- Strides: strides[0] and strides[1] are the row (h) and column (w) strides, respectively.
For the SAME convolution, the output height (YH) and width (YW) are computed as:
YH = ceil(float(XH) / float(strides[0]))
YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as (divisions by 2 are rounded down):
pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
pad_top = pad_along_height / 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width / 2
pad_right = pad_along_width - pad_left
For the VALID convolution, the output height (YH) and width (YW) are computed as:
YH = ceil(float(XH - FH + 1) / float(strides[0]))
YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zero.
Parameters: - x (dynet.Expression) – The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
- f (dynet.Expression) – 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
- stride (list) – the row and column strides (strides[0] and strides[1] above)
Keyword Arguments: is_valid (bool) – ‘VALID’ convolution if True, ‘SAME’ convolution if False (default: True)
Returns: The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Return type: dynet.Expression
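The output-size and padding rules above can be evaluated directly. The sketch below transcribes them into plain Python. The function name conv2d_output_and_padding is hypothetical; it only computes shapes and does not perform a convolution.

```python
import math


def conv2d_output_and_padding(XH, XW, FH, FW, strides, is_valid=True):
    """Return ((YH, YW), (pad_top, pad_bottom, pad_left, pad_right))."""
    sh, sw = strides
    if is_valid:
        # VALID: filters stay inside the input, so no padding is added.
        YH = math.ceil((XH - FH + 1) / sh)
        YW = math.ceil((XW - FW + 1) / sw)
        return (YH, YW), (0, 0, 0, 0)
    # SAME: the output size depends only on the input size and the stride.
    YH = math.ceil(XH / sh)
    YW = math.ceil(XW / sw)
    pad_h = max((YH - 1) * sh + FH - XH, 0)
    pad_w = max((YW - 1) * sw + FW - XW, 0)
    pad_top, pad_left = pad_h // 2, pad_w // 2  # any odd pixel goes to bottom/right
    return (YH, YW), (pad_top, pad_h - pad_top, pad_left, pad_w - pad_left)


# A 5x5 input with a 3x3 filter and stride 1:
print(conv2d_output_and_padding(5, 5, 3, 3, [1, 1], is_valid=True))   # ((3, 3), (0, 0, 0, 0))
print(conv2d_output_and_padding(5, 5, 3, 3, [1, 1], is_valid=False))  # ((5, 5), (1, 1, 1, 1))
```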
-
dynet.conv2d_bias(x, f, b, stride, is_valid=True)¶ 2D convolution with bias
2D convolution operator with bias parameters. VALID and SAME convolutions are supported. When the stride is 1, the distinction is:
SAME: the output size is the same as the input size. To achieve this, the input is padded so that the filter can sweep outside of the input maps.
VALID: the output size shrinks by filter_size - 1, and the filters always sweep at valid positions inside the input maps. No padding is needed.
In detail, assume
- Input feature maps: XH x XW x XC x N
- Filters: FH x FW x XC x FC
- Strides: strides[0] and strides[1] are the row (h) and column (w) strides, respectively.
For the SAME convolution, the output height (YH) and width (YW) are computed as:
YH = ceil(float(XH) / float(strides[0]))
YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as (divisions by 2 are rounded down):
pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
pad_top = pad_along_height / 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width / 2
pad_right = pad_along_width - pad_left
For the VALID convolution, the output height (YH) and width (YW) are computed as:
YH = ceil(float(XH - FH + 1) / float(strides[0]))
YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zero.
Parameters: - x (dynet.Expression) – The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
- f (dynet.Expression) – 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
- b (dynet.Expression) – The bias, one value per output channel (1D: Co)
- stride (list) – the row and column strides (strides[0] and strides[1] above)
Keyword Arguments: is_valid (bool) – ‘VALID’ convolution if True, ‘SAME’ convolution if False (default: True)
Returns: The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Return type: dynet.Expression
-
dynet.filter1d_narrow(x, y)¶
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: TODO
Return type:
-
dynet.kmax_pooling(x, k, d=1)¶ Kmax-pooling operation
Select out the k maximum values along a given dimension, in the same order as they appear. This will result in the size of the given dimension being changed to k.
Parameters: - x (dynet.Expression) – Input expression
- k (unsigned) – Number of maximum values to retrieve along the given dimension
Keyword Arguments: d (unsigned) – Dimension on which to perform kmax-pooling (default: 1)
Returns: An expression in which the size of dimension d is k
Return type: dynet.Expression
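For a single row, "the k maximum values in the same order as they appear" can be sketched as follows. This plain-Python helper (kmax_pooling_1d, a hypothetical name) works on one list of floats. It breaks ties by earlier position, which is an assumption of the sketch and is not documented DyNet behaviour.

```python
def kmax_pooling_1d(values, k):
    """Keep the k largest values of a list, preserving their original order."""
    # Indices of the k largest values (ties broken by earlier position).
    top = sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]
    # Re-sort the selected indices so the output follows the input order.
    return [values[i] for i in sorted(top)]


print(kmax_pooling_1d([3.0, 1.0, 5.0, 2.0, 4.0], 3))  # [3.0, 5.0, 4.0]
```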
Tensor operations¶
-
dynet.contract3d_1d(x, y)¶ Contracts a rank 3 tensor and a rank 1 tensor into a rank 2 tensor
The resulting tensor \(z\) has coordinates \(z_{ij} = \sum_k x_{ijk} y_k\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
Returns: Matrix
Return type: dynet.Expression
-
dynet.contract3d_1d_bias(x, y, b)¶ Same as contract3d_1d with an additional bias parameter
The resulting tensor \(z\) has coordinates \(z_{ij} = b_{ij}+\sum_k x_{ijk} y_k\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- b (dynet.Expression) – Bias matrix
Returns: Matrix
Return type: dynet.Expression
-
dynet.contract3d_1d_1d(x, y, z)¶ Contracts a rank 3 tensor and two rank 1 tensors into a rank 1 tensor
This is the equivalent of calling contract3d_1d and then performing a matrix-vector multiplication.
The resulting tensor \(t\) has coordinates \(t_i = \sum_{j,k} x_{ijk} y_k z_j\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- z (dynet.Expression) – Vector
Returns: Vector
Return type: dynet.Expression
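The two contractions can be checked with nested Python lists indexed as x[i][j][k]. This is an illustration of the index formulas only: these helpers are not the DyNet functions, and they ignore DyNet's column-major storage. The second helper follows the description above by first contracting over k with y and then multiplying the resulting matrix by z.

```python
def contract3d_1d(x, y):
    """z[i][j] = sum_k x[i][j][k] * y[k] for nested lists x (I x J x K), y (K)."""
    return [[sum(xijk * yk for xijk, yk in zip(xij, y)) for xij in xi] for xi in x]


def contract3d_1d_1d(x, y, z):
    """t[i] = sum_{j,k} x[i][j][k] * y[k] * z[j]."""
    m = contract3d_1d(x, y)  # contract over k first
    return [sum(mij * zj for mij, zj in zip(mi, z)) for mi in m]  # then over j


x = [[[1.0, 2.0], [3.0, 4.0]],
     [[5.0, 6.0], [7.0, 8.0]]]
print(contract3d_1d(x, [1.0, 1.0]))                 # [[3.0, 7.0], [11.0, 15.0]]
print(contract3d_1d_1d(x, [1.0, 1.0], [1.0, 0.0]))  # [3.0, 11.0]
```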
-
dynet.contract3d_1d_1d_bias(x, y, z, b)¶ Same as contract3d_1d_1d with an additional bias parameter
This is the equivalent of calling contract3d_1d and then performing an affine transform.
The resulting tensor \(t\) has coordinates \(t_i = b_i + \sum_{j,k} x_{ijk} y_k z_j\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- z (dynet.Expression) – Vector
- b (dynet.Expression) – Bias vector
Returns: Vector
Return type: dynet.Expression
Normalization operations¶
-
dynet.layer_norm(x, g, b)¶ Layer normalization
Performs layer normalization:
\[\begin{split}\begin{split} \mu &= \frac 1 n \sum_{i=1}^n x_i\\ \sigma &= \sqrt{\frac 1 n \sum_{i=1}^n (x_i-\mu)^2}\\ y&=\frac {\boldsymbol{g}} \sigma \circ (\boldsymbol{x}-\mu) + \boldsymbol{b}\\ \end{split}\end{split}\]
Reference: Ba et al., 2016
Parameters: - x (dynet.Expression) – Input expression (possibly batched)
- g (dynet.Expression) – Gain (same dimension as x, no batch dimension)
- b (dynet.Expression) – Bias (same dimension as x, no batch dimension)
Returns: An expression of the same dimension as x
Return type: dynet.Expression
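Below is a plain-Python transcription of the three equations above for a single, unbatched vector. It is a sketch, not DyNet's implementation. It follows the formula literally, so it has no epsilon term and would divide by zero if all entries of x were equal.

```python
import math


def layer_norm(x, g, b):
    """y = g / sigma * (x - mu) + b, with mu and sigma computed over x."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)
    return [gi / sigma * (xi - mu) + bi for xi, gi, bi in zip(x, g, b)]


# With unit gain and zero bias the output has mean 0 and variance 1.
print(layer_norm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4))
```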
Recurrent Neural Networks¶
RNN Builders¶
-
class dynet._RNNBuilder¶
-
disable_dropout()¶ Disable dropout
-
initial_state(vecs=None)¶ Get a dynet.RNNState
This initializes a dynet.RNNState by loading the parameters in the computation graph.
Parameters: vecs (list) – Initial hidden state for each layer as a list of dynet.Expressions (default: None)
Returns: dynet.RNNState used to feed inputs, transduce sequences, etc.
Return type: dynet.RNNState
-
initial_state_from_raw_vectors(vecs=None)¶ Get a dynet.RNNState
This initializes a dynet.RNNState by loading the parameters in the computation graph. Use this if you want to initialize the hidden states with values directly rather than expressions.
Parameters: vecs (list) – Initial hidden state for each layer as a list of numpy arrays (default: None)
Returns: dynet.RNNState used to feed inputs, transduce sequences, etc.
Return type: dynet.RNNState
-
set_dropout(f)¶ Set the dropout rate
Parameters: f (float) – Dropout rate
-
-
class dynet.SimpleRNNBuilder¶ Bases: dynet._RNNBuilder
-
class dynet.GRUBuilder¶ Bases: dynet._RNNBuilder
-
class dynet.LSTMBuilder¶ Bases: dynet._RNNBuilder
-
class dynet.VanillaLSTMBuilder¶ Bases: dynet._RNNBuilder
VanillaLSTMBuilder creates a “standard” LSTM, i.e. one with decoupled input and forget gates and no peephole connections.
This cell runs according to the following dynamics:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]
Parameters: - layers (int) – Number of layers
- input_dim (int) – Dimension of the input
- hidden_dim (int) – Dimension of the recurrent units
- model (dynet.Model) – Model to hold the parameters
- ln_lstm (bool) – Whether to use layer normalization
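One step of the dynamics above can be written out in plain Python for a single layer. This is a sketch under stated assumptions: no dropout, no layer normalization, matrices as lists of rows, and hypothetical helper names (gate_input, vanilla_lstm_step) that are not part of the DyNet API. The "+1" in the forget gate is applied explicitly, as in the formula.

```python
import math


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def tanh(v):
    return [math.tanh(a) for a in v]


def gate_input(params, x, h, extra=0.0):
    """W_x x + W_h h + b (+ extra); W_x, W_h are lists of rows, b is a list."""
    Wx, Wh, b = params
    return [sum(w * xi for w, xi in zip(rx, x))
            + sum(w * hi for w, hi in zip(rh, h)) + bi + extra
            for rx, rh, bi in zip(Wx, Wh, b)]


def vanilla_lstm_step(p, x_t, h_prev, c_prev):
    """One step of the dynamics above; p maps "i", "f", "o", "c" to (W_x, W_h, b)."""
    i = sigmoid(gate_input(p["i"], x_t, h_prev))
    f = sigmoid(gate_input(p["f"], x_t, h_prev, extra=1.0))  # the "+1" forget bias
    o = sigmoid(gate_input(p["o"], x_t, h_prev))
    c_tilde = tanh(gate_input(p["c"], x_t, h_prev))
    c_t = [cp * fi + ct * ii for cp, fi, ct, ii in zip(c_prev, f, c_tilde, i)]
    h_t = [math.tanh(ci) * oi for ci, oi in zip(c_t, o)]
    return h_t, c_t


# Hypothetical sizes: input_dim = 3, hidden_dim = 2, all weights and biases zero.
zero = ([[0.0] * 3 for _ in range(2)], [[0.0] * 2 for _ in range(2)], [0.0] * 2)
params = {gate: zero for gate in "ifoc"}
h, c = vanilla_lstm_step(params, [1.0, 2.0, 3.0], [0.0, 0.0], [1.0, 1.0])
print(h, c)
```

With all-zero weights, only the "+1" forget bias is active: the forget gate equals sigmoid(1), so part of the previous cell state survives, while the candidate cell value is tanh(0) = 0.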
-
set_dropout_masks(batch_size=1)¶ Set dropout masks at the beginning of a sequence for a specific batch size
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element.
You need to call this after calling initial_state.
Parameters: batch_size (int) – Batch size (default: 1)
-
set_dropouts(d, d_r)¶ Set the dropout rates
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016.
More specifically, dropout masks \(\mathbf{z_x}\sim \text{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \text{Bernoulli}(1-d_h)\) are sampled at the start of each sequence.
The dynamics of the cell are then modified to:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f+1)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation.
Parameters: - d (number) – Dropout rate \(d_x\) for the input \(x_t\)
- d_r (number) – Dropout rate \(d_h\) for the output \(h_t\)
-
class dynet.FastLSTMBuilder¶ Bases: dynet._RNNBuilder
-
class dynet.BiRNNBuilder(num_layers, input_dim, hidden_dim, model, rnn_builder_factory, builder_layers=None)¶ Bases: object
Builder for BiRNNs that delegates to regular RNNs and wires them together.
builder = BiRNNBuilder(1, 128, 100, model, LSTMBuilder)
[o1,o2,o3] = builder.transduce([i1,i2,i3])
-
add_inputs(es)¶ Returns the list of state pairs (stateF, stateB) obtained by adding inputs to both the forward (stateF) and backward (stateB) RNNs.
Parameters: es (list) – a list of Expressions
See also transduce(xs).
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState pairs. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
-
transduce(es)¶ Returns the list of output Expressions obtained by adding the given inputs to the current state, one by one, to both the forward and backward RNNs, and concatenating.
Parameters: es (list) – a list of Expressions
See also add_inputs(xs).
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState pairs. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
-
RNN state¶
-
class dynet.RNNState¶ This is the main class for working with RNNs / LSTMs / GRUs. Request an initial RNNState from a builder with initial_state(), and then progress from there.
-
add_input(x)¶ This computes \(h_t = \text{RNN}(x_t)\)
Parameters: x (dynet.Expression) – Input expression
Returns: New RNNState
Return type: dynet.RNNState
-
add_inputs(xs)¶ Returns the list of states obtained by adding the given inputs to the current state, one by one.
See also transduce(xs).
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
Parameters: xs (list) – list of input expressions
Returns: List of new RNNStates
Return type: list
-
b()¶ Get the underlying RNNBuilder
This is useful, for example, if you need to set dropout or other builder options.
Returns: Underlying RNNBuilder
Return type: dynet.RNNBuilder
-
h()¶ Tuple of expressions representing the output of each hidden layer of the current step. The actual output of the network is at h()[-1].
-
prev()¶ Gets the previous RNNState
This is useful in case you need to rewind.
-
s()¶ Tuple of expressions representing the hidden state of the current step.
For SimpleRNN, s() is the same as h(). For LSTM, s() is a series of memory vectors, followed by the series returned by h().
-
set_h(es=None)¶ Manually set the output \(h_t\)
Parameters: es (list) – List of expressions, one for each layer (default: None)
Returns: New RNNState
Return type: dynet.RNNState
-
set_s(es=None)¶ Manually set the hidden states
This is different from set_h because, for LSTMs for instance, this also sets the cell state. The format is [new_c[0],...,new_c[n],new_h[0],...,new_h[n]].
Parameters: es (list) – List of expressions, in this format: [new_c[0],...,new_c[n],new_h[0],...,new_h[n]] (default: None)
Returns: New RNNState
Return type: dynet.RNNState
-
transduce(xs)¶ Returns the list of output Expressions obtained by adding the given inputs to the current state, one by one.
See also add_inputs(xs).
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
Parameters: xs (list) – list of input expressions
Returns: List of output expressions
Return type: list
-
Optimizers¶
-
class dynet.Trainer¶ Generic trainer
-
get_clip_threshold()¶ Get clipping threshold
Returns: Gradient clipping threshold
Return type: number
-
set_clip_threshold(thr)¶ Set clipping threshold
To deactivate clipping, set the threshold to be <= 0.
Parameters: thr (number) – Clipping threshold
-
set_sparse_updates(su)¶ Sets updates to sparse updates
DyNet trainers support two types of updates for lookup parameters, sparse and dense. Sparse updates are the default. They have the potential to be faster, as they only touch the parameters that have non-zero gradients. However, they may not always be faster (particularly on GPU with mini-batch training), and are not precisely numerically correct for some update rules such as MomentumSGDTrainer and AdamTrainer. Thus, if you set su to False, the trainer will perform dense updates, which are precisely correct and sometimes faster.
Parameters: su (bool) – Flag to activate/deactivate sparse updates
-
status()¶ Outputs information about the trainer on stderr
(number of updates since last call, number of clipped gradients, learning rate, etc.)
-
update(s=1.0)¶ Update the parameters
The update equation is different for each trainer; check the online C++ documentation for more details on what each trainer does.
Keyword Arguments: s (number) – Optional scaling factor to apply on the gradient. (default: 1.0)
-
update_epoch(r=1.0)¶ Update trainer hyper-parameters that depend on epochs
In practice, this means learning rate decay.
Keyword Arguments: r (number) – Number of epochs that passed (default: 1.0)
-
update_subset(updated_params, updated_lookups, s=1.0)¶ Update a subset of parameters
Only use this as a last resort. A more elegant way to update only a subset of parameters is to use the “update” keyword in dy.parameter or Parameter.expr() to specify which parameters need to be updated during the creation of the computation graph.
Parameters: - updated_params (list) – Indices of parameters to update
- updated_lookups (list) – Indices of lookup parameters to update
Keyword Arguments: s (number) – Optional scaling factor to apply on the gradient. (default: 1.0)
-
-
class dynet.SimpleSGDTrainer¶ Bases: dynet.Trainer
Stochastic gradient descent trainer
This trainer performs stochastic gradient descent, the go-to optimization procedure for neural networks.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - e0 (number) – Initial learning rate (default: 0.1)
- edecay (number) – Learning rate decay parameter (default: 0.0)
-
class dynet.CyclicalSGDTrainer¶ Bases: dynet.Trainer
This trainer performs stochastic gradient descent with a cyclical learning rate as proposed in Smith, 2015.
This uses a triangular function with optional exponential decay.
More specifically, at each update, the learning rate \(\eta\) is updated according to:
\[\begin{split} \begin{split} \text{cycle} &= \left\lfloor 1 + \frac{\texttt{it}}{2 \times\texttt{step_size}} \right\rfloor\\ x &= \left\vert \frac{\texttt{it}}{\texttt{step_size}} - 2 \times \text{cycle} + 1\right\vert\\ \eta &= \eta_{\text{min}} + (\eta_{\text{max}} - \eta_{\text{min}}) \times \max(0, 1 - x) \times \gamma^{\texttt{it}}\\ \end{split}\end{split}\]
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - e0_min (number) – Lower learning rate (default: 0.01)
- e0_max (number) – Upper learning rate (default: 0.1)
- step_size (number) – Half-period of the triangular function (a full cycle lasts 2 x step_size), in number of iterations (not epochs). According to the original paper, this should be set to around 2-8 times the number of training iterations in an epoch (default: 2000)
- gamma (number) – Learning rate upper bound decay parameter (default: 0.0)
- edecay (number) – Learning rate decay parameter. Ideally you shouldn’t use this with the cyclical learning rate since decay is already handled by \(\gamma\) (default: 0.0)
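The schedule can be evaluated with a few lines of plain Python that transcribe the formula above. This is not the trainer's code. The function name cyclical_learning_rate is hypothetical, and its gamma defaults to 1.0 (no decay, since gamma ** it == 1). That default is a choice made only for this illustration and differs from the trainer default listed above.

```python
import math


def cyclical_learning_rate(it, e0_min=0.01, e0_max=0.1, step_size=2000, gamma=1.0):
    """Learning rate at update `it`, following the triangular formula above."""
    cycle = math.floor(1 + it / (2 * step_size))
    x = abs(it / step_size - 2 * cycle + 1)
    return e0_min + (e0_max - e0_min) * max(0.0, 1 - x) * gamma ** it


# The rate climbs from e0_min to e0_max over step_size updates, then falls back,
# so one full triangle spans 2 * step_size updates.
for it in (0, 1000, 2000, 3000, 4000):
    print(it, cyclical_learning_rate(it))
```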
-
class dynet.MomentumSGDTrainer¶ Bases: dynet.Trainer
Stochastic gradient descent with momentum
This is a modified version of the SGD algorithm with momentum to stabilize the gradient trajectory.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - e0 (number) – Initial learning rate (default: 0.1)
- mom (number) – Momentum (default: 0.9)
- edecay (number) – Learning rate decay parameter (default: 0.0)
-
class dynet.AdagradTrainer¶ Bases: dynet.Trainer
Adagrad optimizer
The Adagrad algorithm assigns a different learning rate to each parameter.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - e0 (number) – Initial learning rate (default: 0.1)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-20)
- edecay (number) – Learning rate decay parameter (default: 0.0)
-
class dynet.AdadeltaTrainer¶ Bases: dynet.Trainer
AdaDelta optimizer
The AdaDelta optimizer is a variant of Adagrad aiming to prevent vanishing learning rates.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-6)
- rho (number) – Update parameter for the moving average of updates in the numerator (default: 0.95)
- edecay (number) – Learning rate decay parameter (default: 0.0)
-
class dynet.RMSPropTrainer¶ Bases: dynet.Trainer
RMSProp optimizer
The RMSProp optimizer is a variant of Adagrad where the squared sum of previous gradients is replaced with a moving average with parameter rho.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - e0 (number) – Initial learning rate (default: 0.001)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-8)
- rho (number) – Update parameter for the moving average (rho = 0 is equivalent to using Adagrad) (default: 0.9)
- edecay (number) – Learning rate decay parameter (default: 0.0)
-
class dynet.AdamTrainer¶ Bases: dynet.Trainer
Adam optimizer
The Adam optimizer is similar to RMSProp but uses unbiased estimates of the first and second moments of the gradient.
Parameters: m (dynet.Model) – Model to be trained
Keyword Arguments: - alpha (number) – Initial learning rate (default: 0.001)
- beta_1 (number) – Moving average parameter for the mean (default: 0.9)
- beta_2 (number) – Moving average parameter for the variance (default: 0.999)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-8)
- edecay (number) – Learning rate decay parameter (default: 0.0)
C++ Reference manual¶
Core functionalities¶
Computation Graph¶
The ComputationGraph is the workhorse of DyNet. From the DyNet technical report:
[The] computation graph represents symbolic computation, and the results of the computation are evaluated lazily: the computation is only performed once the user explicitly asks for it (at which point a “forward” computation is triggered). Expressions that evaluate to scalars (i.e. loss values) can also be used to trigger a “backward” computation, computing the gradients of the computation with respect to the parameters.
-
int dynet::get_number_of_active_graphs()¶ Gets the number of active graphs.
This is 0 or 1; you can’t create more than one graph at once.
- Return
- Number of active graphs
-
unsigned
dynet::
get_current_graph_id
()¶ Get id of the current active graph.
This can help check whether a graph is stale
- Return
- Id of the current graph
-
struct
dynet::
ComputationGraph
¶ - #include <dynet.h>
Computation graph where nodes represent forward and backward intermediate values, and edges represent functions of multiple values.
To represent the fact that a function may have multiple arguments, edges have a single head and 0, 1, 2, or more tails. (Constants, inputs, and parameters are represented as functions of 0 parameters.) Example: given the function z = f(x, y), z, x, and y are nodes, and there is an edge representing f with which points to the z node (i.e., its head), and x and y are the tails of the edge. You shouldn’t need to use most methods from the ComputationGraph except for
backward
since most of them are available directly from the Expression class.Public Functions
- ComputationGraph()
Default constructor.
- VariableIndex add_input(real s)
Add scalar input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation.
- Return
- The index of the created variable
- Parameters
s
: Real number
- VariableIndex add_input(const real *ps)
Add scalar input by pointer.
The computational network will pull inputs in from the user’s data structures and make them available to the computation.
- Return
- The index of the created variable
- Parameters
ps
: Pointer to a real number
- VariableIndex add_input(const Dim &d, const std::vector<float> &data)
Add multidimensional input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation.
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
data
: Input data (as a one-dimensional array)
- VariableIndex add_input(const Dim &d, const std::vector<float> *pdata)
Add multidimensional input by pointer.
The computational network will pull inputs in from the user’s data structures and make them available to the computation.
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
pdata
: Pointer to the input data (as a one-dimensional array)
- VariableIndex add_input(const Dim &d, const std::vector<unsigned int> &ids, const std::vector<float> &data, float defdata = 0.f)
Add sparse input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation. Represents specified (not learned) inputs to the network in sparse array format, with an optional default value.
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
ids
: The indices of the specified data points
data
: The data points corresponding to each index
defdata
: The default value with which to set the unspecified data points
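One way to picture the sparse overload is to expand (ids, data, defdata) into a dense vector, as the sketch below does. It assumes `ids` index into the flattened tensor of `d.size()` elements. `densify` is a hypothetical helper, not part of DyNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Expands a sparse input into its dense equivalent: every position not listed
// in `ids` takes the default value `defdata`. Hypothetical helper.
std::vector<float> densify(std::size_t size, const std::vector<unsigned>& ids,
                           const std::vector<float>& data, float defdata = 0.f) {
  assert(ids.size() == data.size());  // one value per listed index
  std::vector<float> dense(size, defdata);
  for (std::size_t k = 0; k < ids.size(); ++k) {
    assert(ids[k] < size);
    dense[ids[k]] = data[k];
  }
  return dense;
}
```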
- VariableIndex add_parameters(Parameter p)
Add a parameter to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Parameter to be added
- VariableIndex add_parameters(LookupParameter p)
Add a full matrix of lookup parameters to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: LookupParameter to be added
- VariableIndex add_const_parameters(Parameter p)
Add a parameter to the computation graph as a constant (it will not be updated).
- Return
- The index of the created variable
- Parameters
p
: Parameter to be added
- VariableIndex add_const_parameters(LookupParameter p)
Add a full matrix of lookup parameters to the computation graph as a constant (it will not be updated).
- Return
- The index of the created variable
- Parameters
p
: LookupParameter to be added
- VariableIndex add_lookup(LookupParameter p, const unsigned *pindex)
Add a lookup parameter to the computation graph.
Use pindex to point to a memory location, owned by the caller, where the index will live.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindex
: Pointer to the index to look up
- VariableIndex add_lookup(LookupParameter p, unsigned index)
Add a lookup parameter to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
index
: Index to look up
- VariableIndex add_lookup(LookupParameter p, const std::vector<unsigned> *pindices)
Add lookup parameters to the computation graph.
Use pindices to point to a memory location, owned by the caller, where the indices will live.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindices
: Pointer to the indices to look up
- VariableIndex add_lookup(LookupParameter p, const std::vector<unsigned> &indices)
Add lookup parameters to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
indices
: Indices to look up
- VariableIndex add_const_lookup(LookupParameter p, const unsigned *pindex)
Add a lookup parameter to the computation graph.
Just like add_lookup, but the lookup parameters are not optimized.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindex
: Pointer to the index to look up
- VariableIndex add_const_lookup(LookupParameter p, unsigned index)
Add a lookup parameter to the computation graph.
Just like add_lookup, but the lookup parameters are not optimized.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
index
: Index to look up
- VariableIndex add_const_lookup(LookupParameter p, const std::vector<unsigned> *pindices)
Add lookup parameters to the computation graph.
Just like add_lookup, but the lookup parameters are not optimized.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindices
: Pointer to the indices to look up
- VariableIndex add_const_lookup(LookupParameter p, const std::vector<unsigned> &indices)
Add lookup parameters to the computation graph.
Just like add_lookup, but the lookup parameters are not optimized.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
indices
: Indices to look up
- template <class Function> VariableIndex add_function(const std::initializer_list<VariableIndex> &arguments)
Add a function to the computation graph.
This is what is called when creating an expression.
- Return
- The index of the output variable
- Parameters
arguments
: List of the argument indices
- Template Parameters
Function
: Function to be applied
- template <class Function, typename... Args> VariableIndex add_function(const std::initializer_list<VariableIndex> &arguments, Args&&... side_information)
Add a function to the computation graph (with side information).
This is what is called when creating an expression.
- Return
- The index of the output variable
- Parameters
arguments
: List of the argument indices
side_information
: Side information that is needed to compute the function
- Template Parameters
Function
: Function to be applied
- void clear()
Reset the ComputationGraph to a newly created state.
- void checkpoint()
Set a checkpoint.
- void revert()
Revert to the last checkpoint.
- Dim &get_dimension(VariableIndex index) const
Get the dimension of a node.
- Return
- Dimension
- Parameters
index
: Variable index of the node
- const Tensor &forward(const expr::Expression &last)
Run a complete forward pass from the first node to the given one, ignoring all precomputed values.
- Return
- Value of the last Expression after execution
- Parameters
last
: Expression up to which the forward pass must be computed
- const Tensor &forward(VariableIndex i)
Run a complete forward pass from the first node to the given one, ignoring all precomputed values.
- Return
- Value of the end node after execution
- Parameters
i
: Variable index of the node up to which the forward pass must be computed
- const Tensor &incremental_forward(const expr::Expression &last)
Run a forward pass from the last computed node to the given one.
Useful if you want to add nodes and evaluate just the new parts.
- Return
- Value of the last Expression after execution
- Parameters
last
: Expression up to which the forward pass must be computed
- const Tensor &incremental_forward(VariableIndex i)
Run a forward pass from the last computed node to the given one.
Useful if you want to add nodes and evaluate just the new parts.
- Return
- Value of the end node after execution
- Parameters
i
: Variable index of the node up to which the forward pass must be computed
- const Tensor &get_value(VariableIndex i)
Get the forward value for the node at index i.
Performs forward evaluation if the value is not available (may compute more than strictly what is needed).
- Return
- Requested value
- Parameters
i
: Index of the variable from which you want the value
- const Tensor &get_value(const expr::Expression &e)
Get the forward value for the given expression.
Performs forward evaluation if the value is not available (may compute more than strictly what is needed).
- Return
- Requested value
- Parameters
e
: Expression from which you want the value
- const Tensor &get_gradient(VariableIndex i)
Get the gradient for the node at index i.
Performs a backward pass if the gradient is not available (may compute more than strictly what is needed).
- Return
- Requested gradient
- Parameters
i
: Index of the variable from which you want the gradient
- const Tensor &get_gradient(const expr::Expression &e)
Get the gradient for the given expression.
Performs a backward pass if the gradient is not available (may compute more than strictly what is needed).
- Return
- Requested gradient
- Parameters
e
: Expression from which you want the gradient
- void invalidate()
Clears forward caches (for get_value etc.).
- void backward(const expr::Expression &last, bool full = false)
Computes backward gradients from the front-most evaluated node.
The parameter full specifies whether the gradients should be computed for all nodes (true) or only for non-constant nodes.
By default, a node is constant unless
- it is a parameter node, or
- it depends on a non-constant node.
Thus, functions of constants and inputs are considered constants.
Turn full on if you want to retrieve gradients w.r.t. inputs, for instance. By default it is turned off, so that, for efficiency, the backward pass ignores nodes that have no influence on the gradients w.r.t. parameters.
- Parameters
last
: Expression from which to compute the gradient
full
: Whether to compute all gradients (including with respect to constant nodes).
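The constant-node rule above amounts to one pass over the nodes in topological order. The sketch below marks which nodes are non-constant when `full` is false. The names `NodeInfo` and `needs_gradient` are hypothetical, and the code only illustrates the rule; it is not DyNet's backward implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical node description: whether it is a parameter node, and the
// indices of its arguments (always smaller than its own index).
struct NodeInfo {
  bool is_parameter;
  std::vector<std::size_t> args;
};

// A node is non-constant iff it is a parameter node or depends on a
// non-constant node. With full == true, every node is treated as needing a
// gradient.
std::vector<bool> needs_gradient(const std::vector<NodeInfo>& nodes, bool full) {
  std::vector<bool> result(nodes.size(), full);
  if (full) return result;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    bool non_constant = nodes[i].is_parameter;
    for (std::size_t a : nodes[i].args) {
      assert(a < i);  // nodes are stored in topological order
      non_constant = non_constant || result[a];
    }
    result[i] = non_constant;
  }
  return result;
}
```

For example, a function of two inputs stays constant, while anything that consumes a parameter, directly or indirectly, does not.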
- void backward(VariableIndex i, bool full = false)
Computes backward gradients from node i (assuming it has already been evaluated).
The parameter full specifies whether the gradients should be computed for all nodes (true) or only for non-constant nodes.
By default, a node is constant unless
- it is a parameter node, or
- it depends on a non-constant node.
Thus, functions of constants and inputs are considered constants.
Turn full on if you want to retrieve gradients w.r.t. inputs, for instance. By default it is turned off, so that, for efficiency, the backward pass ignores nodes that have no influence on the gradients w.r.t. parameters.
- Parameters
i
: Index of the node from which to compute the gradient
full
: Whether to compute all gradients (including with respect to constant nodes).
- void print_graphviz() const
Used for debugging.
- unsigned get_id() const
Get the unique graph ID.
This ID is incremented by 1 each time a computation graph is created.
- Return
- Graph ID
Nodes¶
Nodes are constituents of the computation graph. The end user doesn’t interact with Nodes but with Expressions.
However, implementing new operations requires creating a new subclass of the Node class described below.
- struct dynet::Node
#include <dynet.h>
Represents an SSA variable.
Contains information on the computation node: arguments, output value and gradient of the output with respect to the function. This class must be inherited to implement any new operation. See nodes.cc for examples. An operation on expressions can then be created from the new Node; see expr.h/expr.cc for examples.
Subclassed by dynet::Abs, dynet::AddVectorToAllColumns, dynet::AffineTransform, dynet::Average, dynet::AverageColumns, dynet::BinaryLogLoss, dynet::BlockDropout, dynet::Concatenate, dynet::ConcatenateToBatch, dynet::ConstantMinusX, dynet::ConstantPlusX, dynet::ConstParameterNode, dynet::ConstScalarMultiply, dynet::Conv2D, dynet::Cube, dynet::CwiseMultiply, dynet::CwiseQuotient, dynet::DotProduct, dynet::Dropout, dynet::Erf, dynet::Exp, dynet::Filter1DNarrow, dynet::FlipGradient, dynet::FoldRows, dynet::GaussianNoise, dynet::Hinge, dynet::HuberDistance, dynet::Identity, dynet::InnerProduct3D_1D, dynet::InnerProduct3D_1D_1D, dynet::InputNode, dynet::KMaxPooling, dynet::KMHNGram, dynet::L1Distance, dynet::Log, dynet::LogDet, dynet::LogGamma, dynet::LogisticSigmoid, dynet::LogSoftmax, dynet::LogSumExp, dynet::MatrixInverse, dynet::MatrixMultiply, dynet::Max, dynet::MaxDimension, dynet::MaxPooling1D, dynet::Min, dynet::MinDimension, dynet::Negate, dynet::NoBackprop, dynet::PairwiseRankLoss, dynet::ParameterNodeBase, dynet::PickBatchElements, dynet::PickElement, dynet::PickNegLogSoftmax, dynet::PickRange, dynet::PoissonRegressionLoss, dynet::Pow, dynet::RandomBernoulli, dynet::RandomGumbel, dynet::RandomNormal, dynet::RandomUniform, dynet::Rectify, dynet::Reshape, dynet::RestrictedLogSoftmax, dynet::ScalarAdd, dynet::ScalarInputNode, dynet::ScalarMultiply, dynet::ScalarQuotient, dynet::SelectCols, dynet::SelectRows, dynet::Softmax, dynet::SoftSign, dynet::SparseInputNode, dynet::Sparsemax, dynet::SparsemaxLoss, dynet::Sqrt, dynet::Square, dynet::SquaredEuclideanDistance, dynet::SquaredNorm, dynet::Sum, dynet::SumBatches, dynet::SumDimension, dynet::SumElements, dynet::Tanh, dynet::TraceOfProduct, dynet::Transpose, dynet::Zeroes
Public Functions
- virtual Dim dim_forward(const std::vector<Dim> &xs) const = 0
Compute the dimensions of the result for the given dimensions of the inputs.
Also checks that the inputs are compatible with each other.
- Return
- Dimension of the output
- Parameters
xs
: Vector containing the dimensions of the inputs
- virtual std::string as_string(const std::vector<std::string> &args) const = 0
Returns important information for debugging.
See nodes-conv.cc for examples.
- Return
- String description of the node
- Parameters
args
: String descriptions of the arguments
- size_t aux_storage_size() const
Size of the auxiliary storage.
In general, this will return an empty size, but if a component needs to store extra information in the forward pass for use in the backward pass, it can request the memory here (NB: you could put it on the Node object, but in general, edges should not allocate tensor memory since memory is managed centrally for the entire computation graph).
- Return
- Size
- virtual void forward_impl(const std::vector<const Tensor *> &xs, Tensor &fx) const = 0
Forward computation.
This function contains the logic for the forward pass. Some implementation remarks from nodes.cc:
- fx can be understood as a pointer to the (preallocated) location for the result of forward to be stored
- fx is not initialized, so after calling forward fx must point to the correct answer
- fx can be repointed to an input, if forward(x) evaluates to x (e.g., in reshaping)
- scalar results of forward are placed in fx.v[0]
- DyNet manages its own memory, not Eigen, and it is configured with the EIGEN_NO_MALLOC option. If you get an error about Eigen attempting to allocate memory, it is (probably) because of an implicit creation of a temporary variable. To tell Eigen this is not necessary, the noalias() method is available. If you really do need a temporary variable, its capacity must be requested by Node::aux_storage_size
Note on debugging problems with differentiable components
- fx is uninitialized when forward is called: are you relying on it being 0?
- Parameters
xs
: Pointers to the inputs
fx
: Pointer to the (preallocated) location for the result of forward to be stored
- virtual void backward_impl(const std::vector<const Tensor *> &xs, const Tensor &fx, const Tensor &dEdf, unsigned i, Tensor &dEdxi) const = 0
Accumulates the derivative of E with respect to the ith argument to f, that is, xs[i].
This function contains the logic for the backward pass. Some implementation remarks from nodes.cc:
- dEdxi MUST ACCUMULATE a result, since multiple nodes (or several arguments of the same node) may depend on the same x_i. Even, e.g., Identity must be implemented as dEdx1 += dEdf. THIS IS EXTREMELY IMPORTANT
- scalar results of forward are placed in fx.v[0]
- DyNet manages its own memory, not Eigen, and it is configured with the EIGEN_NO_MALLOC option. If you get an error about Eigen attempting to allocate memory, it is (probably) because of an implicit creation of a temporary variable. To tell Eigen this is not necessary, the noalias() method is available. If you really do need a temporary variable, its capacity must be requested by Node::aux_storage_size
Note on debugging problems with differentiable components
- dEdxi must accumulate (see the first remark above!)
- Parameters
xs
: Pointers to the inputs
fx
: Output
dEdf
: Gradient of the objective w.r.t. the output of the node
i
: Index of the input w.r.t. which we take the derivative
dEdxi
: Gradient of the objective w.r.t. the input of the node
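Accumulation matters because a node may use the same variable for two arguments, say f = x * x built as a component-wise product of x with itself. A backward pass then calls the backward logic once per argument position, and both contributions must be summed into the same gradient buffer. The scalar sketch below is hypothetical and only mimics the shape of backward_impl.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar stand-in for the backward logic of a product node f = x0 * x1
// (hypothetical, not DyNet's nodes.cc). dEdxi is ACCUMULATED (+=), never
// overwritten, so repeated calls that target the same variable add up.
void cwise_multiply_backward(const std::vector<double>& xs, double dEdf,
                             unsigned i, double& dEdxi) {
  assert(xs.size() == 2 && i < 2);
  dEdxi += dEdf * xs[1 - i];  // d(x0 * x1)/dx_i is the other argument
}
```

With x = 3 used for both arguments, the two calls add 3 + 3 = 6 = 2x, the correct derivative of x * x. Overwriting instead of accumulating would leave 3.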
- virtual bool supports_multibatch() const
Whether this node supports computing multiple batches in one call.
If true, forward and backward will be called once with a multi-batch tensor. If false, forward and backward will be called multiple times, once for each item.
- Return
- Support for multibatch
- void forward(const std::vector<const Tensor *> &xs, Tensor &fx) const
Perform the forward pass in one or multiple calls.
- Parameters
xs
: Pointers to the inputs
fx
: Pointer to the (preallocated) location for the result of forward to be stored
- void backward(const std::vector<const Tensor *> &xs, const Tensor &fx, const Tensor &dEdf, unsigned i, Tensor &dEdxi) const
Perform the backward pass in one or multiple calls.
- Parameters
xs
: Pointers to the inputs
fx
: Output
dEdf
: Gradient of the objective w.r.t. the output of the node
i
: Index of the input w.r.t. which we take the derivative
dEdxi
: Gradient of the objective w.r.t. the input of the node
- unsigned arity() const
Number of arguments to the function.
- Return
- Arity of the function
Public Members
- std::vector<VariableIndex> args
Dependency structure.
- Device *device
Pointer to the device, or null to inherit the device from the first input (or the default device when there is no input).
- void *aux_mem
This will usually be null. But if your node needs to store intermediate values between forward and backward, you can store them here. Request the number of bytes you need from aux_storage_size(). Note: this memory will be on the CPU or GPU, depending on your computation backend.
-
virtual Dim
Parameters and Model¶
Parameters are the things that are optimized. In contrast to a system like Torch, where computational modules may have their own parameters, in DyNet parameters are just parameters.
To deal with sparse updates, there are two parameter classes:
- Parameters represents a vector or matrix (and eventually higher-order tensors) of parameters. These are densely updated.
- LookupParameters represents a table of vectors that are used to embed a set of discrete objects. These are sparsely updated.
- struct dynet::ParameterStorageBase
#include <model.h>
This is the base class for ParameterStorage and LookupParameterStorage, the objects handling the actual parameters.
You can access the storage from any Parameter (resp. LookupParameter) object; use it only for low-level manipulations.
Subclassed by dynet::LookupParameterStorage, dynet::ParameterStorage
Public Functions
- virtual void scale_parameters(float a) = 0
Scale the parameters.
- Parameters
a
: Scale factor
- virtual void zero() = 0
Set the parameters to 0.
- virtual void squared_l2norm(float *sqnorm) const = 0
Get the squared L2 norm of the parameters.
- Parameters
sqnorm
: Pointer to the float holding the result
- virtual void g_squared_l2norm(float *sqnorm) const = 0
Get the squared L2 norm of the gradient w.r.t. these parameters.
- Parameters
sqnorm
: Pointer to the float holding the result
- virtual size_t size() const = 0
Get the size (number of scalar parameters).
- Return
- Number of scalar parameters
-
virtual void
- struct dynet::ParameterStorage
#include <model.h>
Storage class for Parameters.
Inherits from dynet::ParameterStorageBase
Public Functions
- void copy(const ParameterStorage &val)
Copy from another ParameterStorage.
- Parameters
val
: ParameterStorage to copy from
- void accumulate_grad(const Tensor &g)
Add a tensor to the gradient.
After this method gets called, the stored gradient is incremented by g (stored gradient <- stored gradient + g).
- Parameters
g
: Tensor to add
- void clear()
Clear the gradient (set it to 0).
- void clip(float left, float right)
Clip the values to the range [left, right].
-
void
- struct dynet::LookupParameterStorage
#include <model.h>
Storage class for LookupParameters.
Inherits from dynet::ParameterStorageBase
Public Functions
- void initialize(unsigned index, const std::vector<float> &val)
Initialize one particular lookup.
- Parameters
index
: Index of the lookup to initialize
val
: Values
- void copy(const LookupParameterStorage &val)
Copy from another LookupParameterStorage.
- Parameters
val
: Other LookupParameterStorage to copy from
- void accumulate_grad(const Tensor &g)
Add a Tensor to the gradient of the whole lookup matrix.
After this, grads <- grads + g.
- Parameters
g
: Tensor to add
- void accumulate_grad(unsigned index, const Tensor &g)
Add a Tensor to the gradient of one of the lookups.
After this, grads[index] <- grads[index] + g.
- Parameters
index
: Index of the lookup whose gradient is updated
g
: Tensor to add
- void accumulate_grads(unsigned n, const unsigned *ids_host, const unsigned *ids_dev, float *g)
Add tensors to multiple lookups.
After this method gets called, grads[ids_host[i]] <- grads[ids_host[i]] + g[i*dim.size():(i+1)*dim.size()] for each i < n.
- Parameters
n
: Size of ids_host
ids_host
: Indices of the gradients to update
ids_dev
: [To be documented] (only for GPU)
g
: Values
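The following is a CPU-only sketch of the accumulate_grads semantics stated above. The gradient table is stored as one row per lookup entry, and a set of touched entries is updated alongside, in the spirit of the non_zero_grads member. The function name and data layout are hypothetical, and `ids_dev` is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// grads[ids_host[i]] += g[i*dim : (i+1)*dim] for each i < n, recording every
// touched entry in non_zero_grads. Hypothetical stand-in, not DyNet's code.
void accumulate_lookup_grads(unsigned n, const unsigned* ids_host, const float* g,
                             std::size_t dim,
                             std::vector<std::vector<float>>& grads,
                             std::unordered_set<unsigned>& non_zero_grads) {
  for (unsigned i = 0; i < n; ++i) {
    const unsigned id = ids_host[i];
    assert(id < grads.size() && grads[id].size() == dim);
    non_zero_grads.insert(id);
    for (std::size_t j = 0; j < dim; ++j) grads[id][j] += g[i * dim + j];
  }
}
```

A repeated index (the same word appearing twice in a batch, for instance) simply accumulates twice into the same row.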
Public Members
- std::unordered_set<unsigned> non_zero_grads
Gradients are sparse, so track which components are nonzero.
-
void
- struct dynet::Parameter
#include <model.h>
Object representing a trainable parameter.
This object acts as a high-level component linking the actual parameter values (ParameterStorage) and the Model. As long as you don’t want to do low-level hacks at the ParameterStorage level, this is what you will use.
Public Functions
- Parameter()
Default constructor.
- Parameter(Model *mp, unsigned long index)
Constructor.
This is called by the model; you shouldn’t need to use it.
- Parameters
mp
: Pointer to the model
index
: ID of the parameter
- ParameterStorage *get() const
Get the underlying ParameterStorage object.
- Return
- ParameterStorage holding the parameter values
- void zero()
Zero the parameters.
- void set_updated(bool b)
Set the parameter as updated.
- Parameters
b
: Update status
- void scale(float s)
Scales the parameter (multiplies it by s).
- Parameters
s
: Scale
- bool is_updated()
Check the update status.
- Return
- Update status
- void clip_inplace(float left, float right)
Clip the values of the parameter to the range [left, right] (in place).
- struct dynet::LookupParameter
#include <model.h>
Object representing a trainable lookup parameter.
Public Functions
- LookupParameterStorage *get() const
Get the underlying LookupParameterStorage object.
- Return
- LookupParameterStorage holding the parameter values
- void initialize(unsigned index, const std::vector<float> &val) const
Initialize one particular column.
- Parameters
index
: Index of the column to be initialized
val
: Values
- void zero()
Zero the parameters.
- void scale(float s)
Scales the parameter (multiplies it by s).
- Parameters
s
: Scale
- void set_updated(bool b)
Set the parameter as updated.
- Parameters
b
: Update status
- bool is_updated()
Check the update status.
- Return
- Update status
-
LookupParameterStorage *
- struct dynet::ParameterInit
#include <model.h>
Initializers for parameters.
Allows for custom parameter initialization.
Subclassed by dynet::ParameterInitConst, dynet::ParameterInitFromFile, dynet::ParameterInitFromVector, dynet::ParameterInitGlorot, dynet::ParameterInitIdentity, dynet::ParameterInitNormal, dynet::ParameterInitSaxe, dynet::ParameterInitUniform
Public Functions
- ParameterInit()
Default constructor.
- virtual void initialize_params(Tensor &values) const = 0
Function called upon initialization.
Whenever you inherit from this struct to implement your own custom initializer, this is the function you want to override to implement your logic.
- Parameters
values
: The tensor to be initialized. You should modify it in place. See dynet/model.cc for some examples.
- struct dynet::ParameterInitNormal
#include <model.h>
Initialize parameters with samples from a normal distribution.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitNormal(float m = 0.0f, float v = 1.0f)
Constructor.
- Parameters
m
: Mean of the Gaussian distribution
v
: Variance of the Gaussian distribution (reminder: the variance is the square of the standard deviation)
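The constructor takes a variance, but the C++ standard library's std::normal_distribution is parameterized by a standard deviation. A sketch of the sampling step therefore has to take a square root. `init_normal` is a hypothetical helper, not DyNet's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Fill `values` with samples from N(m, v), where v is a VARIANCE.
// std::normal_distribution expects a standard deviation, hence std::sqrt(v).
void init_normal(std::vector<float>& values, float m, float v, std::mt19937& rng) {
  assert(v > 0.f);  // normal_distribution requires a positive standard deviation
  std::normal_distribution<float> dist(m, std::sqrt(v));
  for (float& x : values) x = dist(rng);
}
```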
- struct dynet::ParameterInitUniform
#include <model.h>
Initialize parameters with samples from a uniform distribution.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitUniform(float scale)
Constructor for a uniform distribution centered on 0.
Samples parameters from \(\mathcal U([-\mathrm{scale},+\mathrm{scale}])\).
- Parameters
scale
: Scale of the distribution
- ParameterInitUniform(float l, float r)
Constructor for a uniform distribution in a specific interval.
- Parameters
l
: Lower bound of the interval
r
: Upper bound of the interval
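A sketch of both uniform constructors using the standard library follows. The one-argument form corresponds to l = -scale and r = scale. `init_uniform` is hypothetical. Also note that std::uniform_real_distribution samples from the half-open interval [l, r).

```cpp
#include <cassert>
#include <random>
#include <vector>

// Fill `values` uniformly from [l, r), up to floating-point rounding.
// For the single-scale form, call it with l = -scale and r = scale.
void init_uniform(std::vector<float>& values, float l, float r, std::mt19937& rng) {
  assert(l < r);
  std::uniform_real_distribution<float> dist(l, r);
  for (float& x : values) x = dist(rng);
}
```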
- struct dynet::ParameterInitConst
#include <model.h>
Initialize parameters with a constant value.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitConst(float c)
Constructor.
- Parameters
c
: Constant value
- struct dynet::ParameterInitIdentity
#include <model.h>
Initialize as the identity.
This will raise an exception if used on non-square matrices.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitIdentity()
Constructor.
- struct dynet::ParameterInitGlorot
#include <model.h>
Initialize with the method described in Glorot and Bengio (2010).
In order to preserve the variance of the forward and backward flow across layers, the parameters \(\theta\) are initialized such that \(\mathrm{Var}(\theta)=\frac{2}{n_1+n_2}\), where \(n_1,n_2\) are the input and output dimensions. Important note: the underlying distribution is uniform (not Gaussian).
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitGlorot(bool is_lookup = false, float gain = 1.f)
Constructor.
- Parameters
is_lookup
: Boolean value identifying the parameter as a LookupParameter
gain
: Scaling parameter. In order for the Glorot initialization to be correct, you should set this equal to \(\frac{1}{f'(0)}\), where \(f\) is your activation function
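The Glorot variance translates into a bound for the uniform distribution. \(\mathcal U([-a,a])\) has variance \(a^2/3\), so setting \(a^2/3=\frac{2}{n_1+n_2}\) gives \(a=\sqrt{6/(n_1+n_2)}\). The helper below computes this half-width. Treating `gain` as a multiplier on the bound is an assumption about how DyNet applies it, and `glorot_bound` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>

// Half-width a of U([-a, a]) such that Var = a^2 / 3 = 2 / (n1 + n2).
// Multiplying by `gain` is an assumption of this sketch.
double glorot_bound(unsigned n1, unsigned n2, double gain = 1.0) {
  assert(n1 + n2 > 0);
  return gain * std::sqrt(6.0 / (n1 + n2));
}
```

For a 3x3 matrix the bound is exactly 1, and the resulting variance 1/3 equals 2/(3+3).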
- struct dynet::ParameterInitSaxe
#include <model.h>
Initializes according to Saxe et al. (2014).
Initializes as a random orthogonal matrix (unimplemented for GPU).
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitSaxe(float gain = 1.0)
Constructor.
- struct dynet::ParameterInitFromFile
#include <model.h>
Initializes from a file.
Useful for reusing weights, etc.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitFromFile(std::string f)
Constructor.
- Parameters
f
: File name (the format should just be a list of values)
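The file format is only described as a list of values, so the sketch below assumes whitespace-separated floats and reads them into a vector. `read_values` is hypothetical; check dynet/model.cc for the exact format DyNet expects.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <vector>

// Reads whitespace-separated floats until the stream is exhausted or a token
// fails to parse. The whitespace-separated format is an assumption.
std::vector<float> read_values(std::istream& in) {
  std::vector<float> values;
  float x;
  while (in >> x) values.push_back(x);
  return values;
}
```

With an std::ifstream opened on the file name, the resulting vector could then be handed to ParameterInitFromVector.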
- struct dynet::ParameterInitFromVector
#include <model.h>
Initializes from a std::vector of floats.
Inherits from dynet::ParameterInit
Public Functions
- ParameterInitFromVector(std::vector<float> v)
Constructor.
- Parameters
v
: Vector of values to be used
- class dynet::Model
#include <model.h>
This is a collection of parameters.
If you need a matrix of parameters or a lookup table, ask an instance of this class. It knows how to serialize itself. Parameters know how to track their gradients, but any extra information (like velocity) will live here.
Public Functions
- Model()
Constructor.
- float gradient_l2_norm() const
Returns the L2 norm of your gradient.
Use this to look for vanishing or exploding gradients.
- Return
- L2 norm of the gradient
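A global gradient norm of this kind combines the per-parameter squared norms (the quantity g_squared_l2norm reports for each storage) and then takes a single square root. Here is a minimal sketch with hypothetical names, not DyNet's code.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Global L2 norm over several gradient tensors: sum every squared entry,
// then take one square root at the end.
float gradient_l2_norm(const std::vector<std::vector<float>>& grads) {
  double sq = 0.0;
  for (const auto& g : grads)
    for (float x : g) sq += static_cast<double>(x) * x;
  return static_cast<float>(std::sqrt(sq));
}
```

A norm that shrinks toward zero or grows by orders of magnitude across updates is the vanishing or exploding signal mentioned above.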
- void reset_gradient()
Sets all gradients to zero.
- Parameter add_parameters(const Dim &d, float scale = 0.0f)
Add parameters to the model and return a Parameter object.
Creates a ParameterStorage object holding a tensor of dimension d and returns a Parameter object (to be used as input in the computation graph). The coefficients are sampled according to the scale parameter.
- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parameter
scale
: If scale is non-zero, initializes according to \(\mathcal U([-\mathrm{scale},+\mathrm{scale}])\); otherwise uses Glorot initialization
- Parameter add_parameters(const Dim &d, const ParameterInit &init)
Add parameters with a custom initializer.
- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parameter
init
: Custom initializer
- LookupParameter add_lookup_parameters(unsigned n, const Dim &d)
Add a lookup parameter to the model.
Initializes with Glorot initialization, as add_parameters does by default.
- Return
- LookupParameter object to be used in the computation graph
- Parameters
n
: Number of lookup indices
d
: Dimension of each embedding
- LookupParameter add_lookup_parameters(unsigned n, const Dim &d, const ParameterInit &init)
Add a lookup parameter with a custom initializer.
- Return
- LookupParameter object to be used in the computation graph
- Parameters
n
: Number of lookup indices
d
: Dimension of each embedding
init
: Custom initializer
- void project_weights(float radius = 1.0f)
Project weights so that their L2 norm equals radius.
NOTE (Paul): I am not sure this is doing anything currently. The argument doesn’t seem to be used anywhere. If you need this, raise an issue on GitHub.
- Parameters
radius
: Target norm
- void set_weight_decay_lambda(float lambda)
Set the weight decay coefficient.
- Parameters
lambda
: Weight decay coefficient
- const std::vector<ParameterStorage *> &parameters_list() const
Returns a list of pointers to ParameterStorages.
You shouldn’t need to use this.
- Return
- List of pointers to ParameterStorages
- const std::vector<LookupParameterStorage *> &lookup_parameters_list() const
Returns a list of pointers to LookupParameterStorages.
You shouldn’t need to use this.
- Return
- List of pointers to LookupParameterStorages
- const std::vector<unsigned> &updated_parameters_list() const
Returns the list of indices of updated params.
- Return
- List of indices of updated params
- const std::vector<unsigned> &updated_lookup_parameters_list() const
Returns the list of indices of updated lookup params.
- Return
- List of indices of updated lookup params
- size_t parameter_count() const
Returns the total number of tunable parameters (i.e., scalars) contained within this model.
That is to say, a 2x2 matrix counts as four parameters.
- Return
- Number of parameters
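When counting scalars this way, a tensor contributes the product of its dimensions, and a lookup table contributes its number of entries times the size of one entry. The helpers below are hypothetical, and including lookup tables in the total is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical shape description of a lookup table: n entries of shape dims.
struct LookupShape {
  unsigned n;
  std::vector<unsigned> dims;
};

// Number of scalars in a tensor: the product of its dimensions.
std::size_t tensor_size(const std::vector<unsigned>& dims) {
  std::size_t size = 1;
  for (unsigned d : dims) size *= d;
  return size;
}

// Total scalar count over dense parameters and (by assumption) lookup tables.
std::size_t count_scalars(const std::vector<std::vector<unsigned>>& params,
                          const std::vector<LookupShape>& lookups) {
  std::size_t total = 0;
  for (const auto& dims : params) total += tensor_size(dims);
  for (const auto& l : lookups) total += l.n * tensor_size(l.dims);
  return total;
}
```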
-
size_t
updated_parameter_count
() const¶ Returns total number of (scalar) parameters updated.
- Return
- number of updated parameters
-
void
set_updated_param
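(const Parameter *p, bool status)¶ Set whether a parameter is updated.
Marks whether the parameter p is included in the list of updated parameters (see updated_parameters_list).
- Parameters
p
: Parameter to mark
status
: Whether p should be updated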
-
void
set_updated_lookup_param
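(const LookupParameter *p, bool status)¶ Set whether a lookup parameter is updated.
Marks whether the lookup parameter p is included in the list of updated lookup parameters (see updated_lookup_parameters_list).
- Parameters
p
: Lookup parameter to mark
status
: Whether p should be updated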
-
bool
is_updated_param
(const Parameter *p)¶ Check whether a parameter is updated.
- Return
- Whether p is in the list of updated parameters
- Parameters
p
: Parameter to check
-
bool
is_updated_lookup_param
(const LookupParameter *p)¶ Check whether a lookup parameter is updated.
- Return
- Whether p is in the list of updated lookup parameters
- Parameters
p
: Lookup parameter to check
-
Tensor¶
Tensor objects provide a bridge between C++ data structures and Eigen Tensors for multidimensional data.
Concretely, as an end user you will obtain a tensor object by calling .value()
on an expression. You can then use the functions described below to convert these tensors to floats or arrays of floats, to save and load the values, etc.
Conversely, when implementing low-level nodes (e.g. for new operations), you will need to retrieve Eigen tensors from DyNet tensors in order to perform efficient computation.
-
std::ostream &
dynet::
operator<<
(std::ostream &os, const Tensor &t)¶ You can use
cout << tensor;
for debugging or saving.
- Parameters
os
: Output stream
t
: Tensor
-
real
dynet::
as_scalar
(const Tensor &t)¶ Get a scalar value from an order 0 tensor.
Throws a runtime_error exception if the tensor has more than one element.
TODO: Change for custom invalid dimension exception maybe?
- Return
- Scalar value
- Parameters
t
: Input tensor
-
std::vector<real>
dynet::
as_vector
(const Tensor &v)¶ Get the array of values in the tensor.
For higher-order tensors, this returns the flattened values.
- Return
- Values
- Parameters
v
: Input tensor
-
std::vector<Eigen::DenseIndex>
dynet::
as_vector
(const IndexTensor &v)¶ Get the array of indices in an index tensor.
For higher-order tensors, this returns the flattened values.
- Return
- Index values
- Parameters
v
: Input index tensor
-
real
dynet::
rand01
()¶ This is a helper function to sample uniformly in \([0,1]\).
- Return
- \(x\sim\mathcal U([0,1])\)
-
int
dynet::
rand0n
(int n)¶ This is a helper function to sample uniformly in \(\{0,\dots,n-1\}\).
- Return
- \(x\sim\mathcal U(\{0,\dots,n-1\})\)
- Parameters
n
: Upper bound (excluded)
-
real
dynet::
rand_normal
()¶ This is a helper function to sample from a standard Gaussian distribution.
- Return
- \(x\sim\mathcal N(0,1)\)
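The three helpers above can be approximated with the standard `<random>` header. This is a standalone sketch, not DyNet's implementation: the generator `sketch_rng` and the `sketch_*` function names are hypothetical, and `std::uniform_real_distribution` samples from the half-open interval [0,1) rather than [0,1].

```cpp
#include <cassert>
#include <random>

// Hypothetical generator with a fixed seed, for reproducible illustration only.
std::mt19937 sketch_rng(42);

// Uniform sample in [0, 1).
float sketch_rand01() {
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  return dist(sketch_rng);
}

// Uniform sample in {0, ..., n-1}; n must be positive.
int sketch_rand0n(int n) {
  std::uniform_int_distribution<int> dist(0, n - 1);
  return dist(sketch_rng);
}

// Sample from a standard Gaussian N(0, 1).
float sketch_rand_normal() {
  std::normal_distribution<float> dist(0.0f, 1.0f);
  return dist(sketch_rng);
}
```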
-
struct
dynet::
Tensor
¶ - #include <tensor.h>
Represents a tensor of any order.
This provides a bridge between classic C++ types and Eigen tensors.
Public Functions
-
Tensor
()¶ Create an empty tensor.
-
Tensor
(const Dim &d, float *v, Device *dev, DeviceMempool mem)¶ Creates a tensor.
- Parameters
d
: Shape of the tensor
v
: Pointer to the values
dev
: Device
mem
: Memory pool
-
Eigen::Map<Eigen::MatrixXf>
operator*
()¶ Get the data as an Eigen matrix.
- Return
- Eigen matrix
-
Eigen::Map<Eigen::VectorXf>
vec
()¶ Get the data as an Eigen vector.
This returns the full tensor contents, even if it has many dimensions.
- Return
- Flattened tensor
-
Eigen::TensorMap<Eigen::Tensor<float, 1>>
tvec
()¶ Get the data as an order 1 Eigen tensor.
This returns the full tensor contents as a one-dimensional Eigen tensor, which can be used for on-device processing where dimensions aren't important.
- Return
- Eigen order 1 tensor
-
Eigen::TensorMap<Eigen::Tensor<float, 2>>
tbvec
()¶ Get the data as an order 2 tensor including batch size.
This returns the full tensor contents as a two-dimensional Eigen tensor, where the first dimension is a flattened representation of each batch element and the second dimension indexes the batch elements.
- Return
- Tensor of shape d.batch_size() x d.batch_elems() (elements per batch x number of batch elements)
- template <int Order>
-
Eigen::TensorMap<Eigen::Tensor<float, Order + 1>>
tb
()¶ Get view as an Eigen Tensor where the final dimension indexes the batch elements.
-
float *
batch_ptr
(unsigned bid)¶ Get the pointer for a particular batch.
Broadcasts automatically if the tensor has only a single batch element.
- Return
- Pointer to the memory where the batch values are located
- Parameters
bid
: Batch id requested
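A minimal model of the pointer arithmetic behind batch_ptr, assuming that batch elements are stored contiguously, each occupying d.batch_size() floats, and that broadcasting applies when there is a single batch element. `batch_offset` is a hypothetical name; it returns an offset into the value array rather than a pointer.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model: offset (in floats) of batch element `bid`.
// With batch_elems == 1, every bid maps to the single element (broadcasting).
std::size_t batch_offset(unsigned bid, unsigned batch_elems, unsigned batch_size) {
  assert(batch_elems == 1 || bid < batch_elems);  // out-of-range ids are errors
  return static_cast<std::size_t>(bid % batch_elems) * batch_size;
}
```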
-
Eigen::Map<Eigen::MatrixXf>
batch_matrix
(unsigned bid)¶ Get the matrix for a particular batch.
Broadcasts automatically if the tensor has only a single batch element.
- Return
- Matrix at batch id bid (of shape d.rows() x d.cols())
- Parameters
bid
: Batch id requested
-
Eigen::Map<Eigen::MatrixXf>
rowcol_matrix
()¶ Get the data as a matrix, where each “row” is the concatenation of rows and columns, and each “column” is batches.
- Return
- Matrix of shape d.rows() * d.cols() x d.batch_elems()
-
Eigen::Map<Eigen::MatrixXf>
colbatch_matrix
()¶ Get the data as a matrix, where each “row” is the concatenation of rows, and each “column” is the concatenation of columns and batches.
- Return
- Matrix of shape d.rows() x d.cols() * d.batch_elems()
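Following the prose descriptions of the two views (rowcol_matrix folds rows and columns into each column of the view; colbatch_matrix keeps the rows and folds columns and batches together), a hypothetical shape calculation for a tensor of shape rows x cols with batch_elems batch elements could look like this. It is plain C++ with no DyNet types.

```cpp
#include <cassert>
#include <utility>

// Hypothetical: (rows, cols) of the rowcol_matrix view.
std::pair<unsigned, unsigned> rowcol_shape(unsigned rows, unsigned cols, unsigned batch_elems) {
  return {rows * cols, batch_elems};
}

// Hypothetical: (rows, cols) of the colbatch_matrix view.
std::pair<unsigned, unsigned> colbatch_shape(unsigned rows, unsigned cols, unsigned batch_elems) {
  return {rows, cols * batch_elems};
}
```

For a 2x3 tensor with 4 batch elements, the two views are 6x4 and 2x12 respectively; both cover the same 24 values.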
-
bool
is_valid
() const¶ Check for NaNs and infinite values.
This is slow (it is linear in the number of elements), so use it sparingly. It raises a std::runtime_error exception if the Tensor is on GPU, because this case is not implemented yet.
- Return
- Whether all values in the tensor are valid (i.e. it contains no NaN or infinite values)
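On the CPU, the check amounts to one pass over the values. The sketch below is a hypothetical standalone equivalent (`all_finite` is not a DyNet function) that returns true only when no value is NaN or infinite.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical CPU-side check: a single linear pass that stops at the
// first NaN or infinite value.
bool all_finite(const std::vector<float>& values) {
  for (float x : values) {
    if (std::isnan(x) || std::isinf(x)) return false;
  }
  return true;
}
```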
-
Tensor
batch_elem
(unsigned b) const¶ Get a Tensor object representing a single batch.
If this tensor only has a single batch element, it is broadcast. Otherwise, this checks that the requested batch id is smaller than the number of batch elements.
TODO: This is a bit wasteful, as it re-calculates
bs.batch_size()
every time.
- Return
- Sub tensor at batch b
- Parameters
b
: Batch id
-
-
struct
dynet::
IndexTensor
¶ - #include <tensor.h>
Represents a tensor of indices.
This holds indices to locations within a dimension or tensor.
Public Functions
-
IndexTensor
()¶ Create an empty tensor.
-
IndexTensor
(const Dim &d, Eigen::DenseIndex *v, Device *dev, DeviceMempool mem)¶ Creates a tensor.
- Parameters
d
: Shape of the tensor
v
: Pointer to the values
dev
: Device
mem
: Memory pool
- template <int Order>
-
Eigen::TensorMap<Eigen::Tensor<Eigen::DenseIndex, Order>>
t
()¶ Get view as a Tensor.
-
-
struct
dynet::
TensorTools
¶ - #include <tensor.h>
Provides tools for creating, accessing, copying and modifying tensors (in place).
Public Static Functions
-
void
clip
(Tensor &d, float left, float right)¶ Clip the values in the tensor to a fixed range.
- Parameters
d
: Tensor to modify
left
: Target minimum value
right
: Target maximum value
-
void
constant
(Tensor &d, float c)¶ Fills the tensor with a constant value.
- Parameters
d
: Tensor to modify
c
: Target value
-
void
identity
(Tensor &val)¶ Set the (order 2) tensor as the identity matrix.
This throws a runtime_error exception if the tensor isn't a square matrix.
- Parameters
val
: Input tensor
-
void
randomize_bernoulli
(Tensor &val, real p, real scale = 1.0f)¶ Fill the tensor with Bernoulli random variables and scale them by scale.
- Parameters
val
: Input tensor
p
: Parameter of the Bernoulli distribution
scale
: Scale of the random variables
-
void
randomize_normal
(Tensor &val, real mean = 0.0f, real stddev = 1.0f)¶ Fill the tensor with Gaussian random variables.
- Parameters
val
: Input tensor
mean
: Mean
stddev
: Standard deviation
-
void
randomize_uniform
(Tensor &val, real left = 0.0f, real right = 0.0f)¶ Fill the tensor with uniform random variables.
- Parameters
val
: Input tensor
left
: Left bound of the interval
right
: Right bound of the interval
-
void
randomize_orthonormal
(Tensor &val, real scale = 1.0f)¶ Takes a square matrix tensor and sets it as a random orthonormal matrix.
More specifically, this samples a random matrix with randomize_uniform, performs an SVD, and returns the left orthonormal matrix of the decomposition, scaled by scale.
- Parameters
val
: Input tensor
scale
: Value to which the resulting orthonormal matrix will be scaled
-
float
access_element
(const Tensor &v, int index)¶ Access element of the tensor by index in the values array.
access_element and set_element are potentially very slow; use them sparingly.
- Return
v.v[index]
- Parameters
v
: Tensor
index
: Index in the memory
-
float
access_element
(const Tensor &v, const Dim &index)¶ Access element of the tensor by indices in the various dimensions.
This only works for matrix-shaped tensors (+ batch dimension). access_element and set_element are potentially very slow; use them sparingly.
- Return
(*v)(index[0], index[1])
- Parameters
v
: Tensor
index
: Indices in the tensor
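Because DyNet tensors are stored in column-major order (the input documentation states this, and it is Eigen's default), the matrix element (index[0], index[1]) sits at offset index[0] + index[1] * rows in the values array. A hypothetical standalone helper illustrating that mapping:

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: read element (i, j) of a rows x cols matrix stored
// column-major in a flat array, i.e. at offset i + j * rows.
float element_at(const std::vector<float>& values, unsigned rows, unsigned i, unsigned j) {
  return values[i + j * rows];
}
```

With the 2x3 matrix stored as {1, 2, 3, 4, 5, 6}, the columns are (1, 2), (3, 4) and (5, 6), so element (1, 0) is 2 and element (0, 2) is 5.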
-
void
set_element
(const Tensor &v, int index, float value)¶ Set element of the tensor by index in the values array.
access_element and set_element are potentially very slow; use them sparingly.
- Parameters
v
: Tensor
index
: Index in the memory
value
: Desired value
-
void
copy_element
(const Tensor &l, int lindex, Tensor &r, int rindex)¶ Copy an element from one tensor to another (by index in the values array).
- Parameters
l
: Source tensor
lindex
: Source index
r
: Target tensor
rindex
: Target index
-
void
set_elements
(const Tensor &v, const std::vector<float> &vec)¶ Set the elements of a tensor with an array of values.
(This uses memcpy, so be careful.)
- Parameters
v
: Input Tensor
vec
: Values
-
void
copy_elements
(const Tensor &v, const Tensor &v_src)¶ Copy one tensor into another.
- Parameters
v
: Target tensor
v_src
: Source tensor
-
IndexTensor
argmax
(const Tensor &v, unsigned dim = 0, unsigned num = 1)¶ Calculate the index of the maximum value.
- Return
- A newly allocated IndexTensor consisting of argmax IDs. The length of the dimension “dim” will be “num”, consisting of the appropriate IDs.
- Parameters
v
: A tensor where each row represents a probability distribution
dim
: Which dimension to take the argmax over
num
: The number of kmax values
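For the common case of num = 1 on a matrix, reducing over dim 0 can be pictured as picking one row index per column. This is a hypothetical standalone sketch of that case, assuming column-major storage; it is not DyNet's implementation and ignores batching and num > 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of argmax over dimension 0 with num = 1: for a
// rows x cols matrix stored column-major, return the row index of the
// largest entry in each column (first index wins on ties).
std::vector<std::size_t> argmax_dim0(const std::vector<float>& values,
                                     std::size_t rows, std::size_t cols) {
  std::vector<std::size_t> ids(cols, 0);
  for (std::size_t j = 0; j < cols; ++j) {
    for (std::size_t i = 1; i < rows; ++i) {
      if (values[i + j * rows] > values[ids[j] + j * rows]) ids[j] = i;
    }
  }
  return ids;
}
```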
-
IndexTensor
categorical_sample_log_prob
(const Tensor &v, unsigned dim = 0, unsigned num = 1)¶ Calculate samples from a log probability.
- Return
- A newly allocated IndexTensor consisting of the sampled IDs. The length of the dimension “dim” will be “num”, consisting of the appropriate IDs.
- Parameters
v
: A tensor where each row represents a log probability distribution
dim
: Which dimension to take the sample over
num
: The number of samples for each row
Dimensions¶
The Dim class holds information on the shape of a tensor. As explained in Unorthodox Design, in DyNet the dimensions are represented as the standard dimension + the batch dimension, which makes batched computation transparent.
-
DYNET_MAX_TENSOR_DIM
7¶ Maximum number of dimensions supported by dynet : 7
-
struct
dynet::
Dim
¶ - #include <dim.h>
The Dim struct stores information about the dimensionality of expressions.
Batch dimension is treated separately from standard dimension.
Public Functions
-
Dim
()¶ Default constructor.
-
Dim
(std::initializer_list<unsigned int> x)¶ Initialize from a list of dimensions.
The batch dimension is 1 in this case (non-batched expression).
- Parameters
x
: List of dimensions
-
Dim
(std::initializer_list<unsigned int> x, unsigned int b)¶ Initialize from a list of dimensions and a batch size.
- Parameters
x
: List of dimensions
b
: Batch size
-
Dim
(const std::vector<long> &x)¶ Initialize from a vector of dimensions.
The batch dimension is 1 in this case (non-batched expression).
- Parameters
x
: Array of dimensions
-
Dim
(const std::vector<long> &x, unsigned int b)¶ Initialize from a vector of dimensions and a batch size.
- Parameters
x
: Vector of dimensions
b
: Batch size
-
unsigned int
size
() const¶ Total size of the tensor, including all batch elements.
- Return
- Number of batch elements * size of a batch (batch_elems() * batch_size())
-
unsigned int
batch_size
() const¶ Size of a batch (product of all dimensions)
- Return
- Size of a batch
-
unsigned int
sum_dims
() const¶ Sum of all dimensions within a batch.
- Return
- Sum of the dimensions within a batch
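The difference between the per-element size, the total size, and sum_dims can be seen in a small hypothetical model. `DimSketch` is not DyNet's Dim; it only mirrors the accessors described above, with `bd` standing for the number of batch elements.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of Dim's size accessors (not DyNet's struct).
struct DimSketch {
  std::vector<unsigned> dims;  // per-element shape, e.g. {50, 30}
  unsigned bd = 1;             // number of batch elements

  unsigned batch_size() const {  // product of dims: size of one batch element
    unsigned p = 1;
    for (unsigned d : dims) p *= d;
    return p;
  }
  unsigned sum_dims() const {  // sum of dims within one batch element
    unsigned s = 0;
    for (unsigned d : dims) s += d;
    return s;
  }
  unsigned size() const { return batch_size() * bd; }  // all batch elements
};
```

For a 50x30 shape with 10 batch elements, batch_size() is 1500, sum_dims() is 80 and size() is 15000.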
-
void
resize
(unsigned int i)¶ Change the number of dimensions.
- Parameters
i
: New number of dimensions
-
unsigned int
ndims
() const¶ Get number of dimensions.
- Return
- Number of dimensions
-
unsigned int
rows
() const¶ Size of the first dimension.
- Return
- Size of the first dimension
-
unsigned int
num_nonone_dims
() const¶ Number of non-one dimensions.
- Return
- Number of non-one dimensions
-
unsigned int
cols
() const¶ Size of the second dimension (or 1 if only one dimension)
- Return
- Size of the second dimension (or 1 if only one dimension)
-
unsigned int
batch_elems
() const¶ Number of batch elements (size of the batch dimension).
- Return
- Number of batch elements
-
void
set
(unsigned int i, unsigned int s)¶ Set specific dimension.
Set the value of a specific dimension to an arbitrary value.
- Parameters
i
: Dimension index
s
: Dimension size
-
unsigned int
operator[]
(unsigned int i) const¶ Access a specific dimension as you would access an array element.
- Return
- Size of dimension i
- Parameters
i
: Dimension index
-
unsigned int
size
(unsigned int i) const¶ Size of dimension i.
- Return
- Size of dimension i
- Parameters
i
: Dimension index
-
void
delete_dim
(unsigned int i)¶ Remove one of the dimensions.
- Parameters
i
: Index of the dimension to be removed
-
void
delete_dims
(std::vector<unsigned int> dims, bool reduce_batch)¶ Remove multiple dimensions.
- Parameters
dims
: Dimensions to be removed
reduce_batch
: Whether to also reduce the batch dimension
-
void
insert_dim
(unsigned int i, unsigned int n)¶ Insert a dimension.
- Parameters
i
: The index before which to insert the new dimension
n
: The size of the new dimension
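On a plain list of dimension sizes, delete_dim and insert_dim behave like erasing from and inserting into a vector. The helpers below are hypothetical stand-ins that ignore the batch dimension.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for delete_dim: remove the dimension at index i.
void delete_dim_sketch(std::vector<unsigned>& dims, unsigned i) {
  assert(i < dims.size());
  dims.erase(dims.begin() + i);
}

// Hypothetical stand-in for insert_dim: insert a dimension of size n
// before index i (i == dims.size() appends at the end).
void insert_dim_sketch(std::vector<unsigned>& dims, unsigned i, unsigned n) {
  assert(i <= dims.size());
  dims.insert(dims.begin() + i, n);
}
```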
-
Operations¶
Operation Interface¶
The following functions define DyNet “Expressions,” which are used as an interface to the various functions that can be used to build DyNet computation graphs. Expressions for each specific function are listed below.
-
struct
dynet::expr::
Expression
¶ - #include <expr.h>
Expressions are the building blocks of a DyNet computation graph.
Public Functions
-
Expression
(ComputationGraph *pg, VariableIndex i)¶ Base expression constructor.
Used when creating operations.
- Parameters
pg
: Pointer to the computation graph
i
: Variable index
-
const Tensor &
value
() const¶ Get value of the expression.
Throws a runtime_error exception if no computation graph is available.
- Return
- Value of the expression as a tensor
-
const Tensor &
gradient
() const¶ Get gradient of the expression.
Throws a runtime_error exception if no computation graph is available.
Make sure to call
backward
on a downstream expression before calling this. If the expression is a constant expression (meaning it's not a function of a parameter), DyNet won't compute its gradient for the sake of efficiency. You need to manually force the gradient computation by passing the argument
full=true
to backward.
- Return
- Gradient of the expression as a tensor
-
Input Operations¶
These operations allow you to input something into the computation graph, either simple scalar/vector/matrix inputs from floats, or parameter inputs from a DyNet parameter object. They all require passing a computation graph as input so that it is clear which graph is being used for this particular calculation.
-
Expression
dynet::expr::
input
(ComputationGraph &g, real s)¶ Scalar input.
Create an expression that represents the scalar value s
- Return
- An expression representing s
- Parameters
g
: Computation graph
s
: Real number
-
Expression
dynet::expr::
input
(ComputationGraph &g, const real *ps)¶ Modifiable scalar input.
Create an expression that represents the scalar value *ps. If *ps is changed and the computation graph recalculated, the next forward pass will reflect the new value.
- Return
- An expression representing *ps
- Parameters
g
: Computation graph
ps
: Real number pointer
-
Expression
dynet::expr::
input
(ComputationGraph &g, const Dim &d, const std::vector<float> &data)¶ Vector/matrix/tensor input.
Create an expression that represents a vector, matrix, or tensor input. The dimensions of the input are defined by d. For example,
input(g, {50}, data)
will result in a 50-length vector, and
input(g, {50,30}, data)
will result in a 50x30 matrix, and so on, for an arbitrary number of dimensions. This function can also be used to import minibatched inputs. For example, if we have 10 examples in a minibatch, each with size 50x30, then we call
input(g, Dim({50,30}, 10), data)
The data vector “data” contains the values used to fill the input, in column-major format. Its length must equal the product of all dimensions in d (including the batch size).
- Return
- An expression representing data
- Parameters
g
: Computation graph
d
: Dimension of the input matrix
data
: A vector of data points
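Because data is read in column-major order, a matrix written row by row has to be reordered when flattened. A hypothetical helper (not part of DyNet) that produces the vector to pass as data for input(g, {rows, cols}, data):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: flatten a row-major nested array (m[row][col]) into
// the column-major order expected by the data argument of input().
std::vector<float> to_column_major(const std::vector<std::vector<float>>& m) {
  const std::size_t rows = m.size();
  const std::size_t cols = rows ? m[0].size() : 0;
  std::vector<float> data;
  data.reserve(rows * cols);
  for (std::size_t j = 0; j < cols; ++j)    // walk columns first...
    for (std::size_t i = 0; i < rows; ++i)  // ...then rows within a column
      data.push_back(m[i][j]);
  return data;
}
```

For a minibatch, the column-major vectors of the individual examples would presumably be concatenated one after another; this assumes batch elements are stored contiguously.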
-
Expression
dynet::expr::
input
(ComputationGraph &g, const Dim &d, const std::vector<float> *pdata)¶ Updatable vector/matrix/tensor input.
Like the input function that takes a vector reference, this inputs a vector, matrix, or tensor. Because we pass a pointer, the data can be updated.
- Return
- An expression representing *pdata
- Parameters
g
: Computation graph
d
: Dimension of the input matrix
pdata
: A pointer to an (updatable) vector of data points
-
Expression
dynet::expr::
input
(ComputationGraph &g, const Dim &d, const std::vector<unsigned int> &ids, const std::vector<float> &data, float defdata = 0.f)¶ Sparse vector input.
This operation takes input as a sparse matrix of index/value pairs. It is exactly the same as the standard input via vector reference, but sets all non-specified values to “defdata” and resets all others to the appropriate input values.
- Return
- An expression representing data
- Parameters
g
: Computation graph
d
: Dimension of the input matrix
ids
: The indexes of the data points to update
data
: The data points corresponding to each index
defdata
: The default data with which to set the unspecified data points
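The dense values produced by the sparse input can be modeled as follows. `densify` is a hypothetical standalone function that assumes ids index into the flattened (column-major) input; every position not listed in ids receives defdata.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of sparse input: a dense vector of `size` values, all
// set to defdata except the positions listed in ids.
std::vector<float> densify(std::size_t size, const std::vector<unsigned>& ids,
                           const std::vector<float>& data, float defdata = 0.f) {
  assert(ids.size() == data.size());  // one value per index
  std::vector<float> dense(size, defdata);
  for (std::size_t k = 0; k < ids.size(); ++k) dense[ids[k]] = data[k];
  return dense;
}
```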
-
Expression
dynet::expr::
parameter
(ComputationGraph &g, Parameter p)¶ Load parameter.
Load parameters into the computation graph.
- Return
- An expression representing p
- Parameters
g
: Computation graph
p
: Parameter object to load
-
Expression
dynet::expr::
parameter
(ComputationGraph &g, LookupParameter lp)¶ Load lookup parameter.
Load a full tensor of lookup parameters into the computation graph. Normally, lookup parameters are accessed by using the lookup() function to grab a single element. However, in some cases we'll want to access all of the parameters in the entire set of lookup parameters. In this case, you can use this function: the first dimensions of the returned tensor will be equal to the dimensions we would get by calling the lookup() function, and the size of the final dimension will be equal to the size of the vocabulary.
- Return
- An expression representing lp
- Parameters
g
: Computation graph
lp
: LookupParameter object to load
-
Expression
dynet::expr::
const_parameter
(ComputationGraph &g, Parameter p)¶ Load constant parameters.
Load parameters into the computation graph, but prevent them from being updated when performing parameter update.
- Return
- An expression representing the constant p
- Parameters
g
: Computation graph
p
: Parameter object to load
-
Expression
dynet::expr::
const_parameter
(ComputationGraph &g, LookupParameter lp)¶ Load constant lookup parameters.
Load lookup parameters into the computation graph, but prevent them from being updated when performing parameter update.
- Return
- An expression representing the constant lp
- Parameters
g
: Computation graph
lp
: LookupParameter object to load
-
Expression
dynet::expr::
lookup
(ComputationGraph &g, LookupParameter p, unsigned index)¶ Look up parameter.
Look up parameters according to an index, and load them into the computation graph.
- Return
- An expression representing p[index]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
index
: Index of the parameters within p
-
Expression
dynet::expr::
lookup
(ComputationGraph &g, LookupParameter p, const unsigned *pindex)¶ Look up parameters with modifiable index.
Look up parameters according to the *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change.
- Return
- An expression representing p[*pindex]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
pindex
: Pointer index of the parameters within p
-
Expression
dynet::expr::
const_lookup
(ComputationGraph &g, LookupParameter p, unsigned index)¶ Look up parameter.
Look up parameters according to an index, and load them into the computation graph. Do not perform gradient update on the parameters.
- Return
- A constant expression representing p[index]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
index
: Index of the parameters within p
-
Expression
dynet::expr::
const_lookup
(ComputationGraph &g, LookupParameter p, const unsigned *pindex)¶ Constant lookup parameters with modifiable index.
Look up parameters according to the *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change. However, gradient updates will not be performed.
- Return
- A constant expression representing p[*pindex]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
pindex
: Pointer index of the parameters within p
-
Expression
dynet::expr::
lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)¶ Look up parameters.
The mini-batched version of lookup. The resulting expression will be a mini-batch of parameters, where the “i”th element of the batch corresponds to the parameters at the position specified by the “i”th element of “indices”
- Return
- An expression with the “i”th batch element representing p[indices[i]]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
indices
: Index of the parameters at each position in the batch
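Conceptually, the mini-batched lookup gathers one embedding per index into a batch. In this hypothetical sketch the lookup table is a plain vector of rows (not a LookupParameter), and batch element i is a copy of row indices[i]; repeated indices simply produce repeated rows.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of mini-batched lookup: batch element i is a copy of
// table[indices[i]]. std::vector::at throws on an out-of-range index.
std::vector<std::vector<float>> batched_lookup(const std::vector<std::vector<float>>& table,
                                               const std::vector<unsigned>& indices) {
  std::vector<std::vector<float>> batch;
  batch.reserve(indices.size());
  for (unsigned idx : indices) batch.push_back(table.at(idx));
  return batch;
}
```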
-
Expression
dynet::expr::
lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)¶ Look up parameters.
The mini-batched version of lookup with modifiable parameter indices.
- Return
- An expression with the “i”th batch element representing p[(*pindices)[i]]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
pindices
: Pointer to lookup indices
-
Expression
dynet::expr::
const_lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)¶ Look up parameters.
Mini-batched lookup that will not update the parameters.
- Return
- A constant expression with the “i”th batch element representing p[indices[i]]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
indices
: Lookup indices
-
Expression
dynet::expr::
const_lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)¶ Look up parameters.
Mini-batched lookup that will not update the parameters, with modifiable indices.
- Return
- A constant expression with the “i”th batch element representing p[(*pindices)[i]]
- Parameters
g
: Computation graph
p
: LookupParameter object from which to load
pindices
: Pointer to lookup indices
-
Expression
dynet::expr::
zeroes
(ComputationGraph &g, const Dim &d)¶ Create an input full of zeros.
Create an input full of zeros, sized according to dimensions d.
- Return
- A “d” dimensioned zero vector
- Parameters
g
: Computation graph
d
: The dimensions of the input
-
Expression
dynet::expr::
random_normal
(ComputationGraph &g, const Dim &d)¶ Create a random normal vector.
Create a vector distributed according to normal distribution with mean 0, variance 1.
- Return
- A “d” dimensioned normally distributed vector
- Parameters
g
: Computation graph
d
: The dimensions of the input
-
Expression
dynet::expr::
random_bernoulli
(ComputationGraph &g, const Dim &d, real p, real scale = 1.0f)¶ Create a random bernoulli vector.
Create a vector distributed according to bernoulli distribution with parameter p.
- Return
- A “d” dimensioned bernoulli distributed vector
- Parameters
g
: Computation graph
d
: The dimensions of the input
p
: The Bernoulli p parameter
scale
: A scaling factor for the output (“active” elements will receive this value)
-
Expression
dynet::expr::
random_uniform
(ComputationGraph &g, const Dim &d, real left, real right)¶ Create a random uniform vector.
Create a vector distributed according to uniform distribution with boundaries left and right.
- Return
- A “d” dimensioned uniform distributed vector
- Parameters
g
: Computation graph
d
: The dimensions of the input
left
: The left boundary
right
: The right boundary
-
Expression
dynet::expr::
random_gumbel
(ComputationGraph &g, const Dim &d, real mu = 0.0, real beta = 1.0)¶ Create a random Gumbel sampled vector.
Create a vector distributed according to a Gumbel distribution with the specified parameters. (Currently only the defaults of mu=0.0 and beta=1.0 are supported.)
- Return
- A “d” dimensioned Gumbel distributed vector
- Parameters
g
: Computation graph
d
: The dimensions of the input
mu
: The mu parameter
beta
: The beta parameter
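For reference, Gumbel samples are commonly drawn by inverse transform sampling, x = mu - beta * log(-log(u)) with u uniform in (0, 1). The sketch below applies that formula with `<random>`; `sample_gumbel` is a hypothetical name, and this is an illustration of the distribution, not a description of how random_gumbel is implemented.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical inverse-CDF sampler for Gumbel(mu, beta).
double sample_gumbel(std::mt19937& gen, double mu = 0.0, double beta = 1.0) {
  std::uniform_real_distribution<double> unif(0.0, 1.0);  // samples [0, 1)
  double u = unif(gen);
  while (u <= 0.0) u = unif(gen);  // log(0) is undefined, so resample
  return mu - beta * std::log(-std::log(u));
}
```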
Arithmetic Operations¶
These operations perform basic arithmetic over values in the graph.
-
Expression
dynet::expr::
operator-
(const Expression &x)¶ Negation.
Negate the passed argument.
- Return
- The negation of x
- Parameters
x
: An input expression
-
Expression
dynet::expr::
operator+
(const Expression &x, const Expression &y)¶ Expression addition.
Add two expressions of the same dimensions.
- Return
- The sum of x and y
- Parameters
x
: The first input
y
: The second input
-
Expression
dynet::expr::
operator+
(const Expression &x, real y)¶ Scalar addition.
Add a scalar to an expression
- Return
- An expression equal to x, with every component increased by y
- Parameters
x
: The expression
y
: The scalar
-
Expression
dynet::expr::
operator+
(real x, const Expression &y)¶ Scalar addition.
Add a scalar to an expression
- Return
- An expression equal to y, with every component increased by x
- Parameters
x
: The scalar
y
: The expression
-
Expression
dynet::expr::
operator-
(const Expression &x, const Expression &y)¶ Expression subtraction.
Subtract one expression from another.
- Return
- An expression where the ith element is x_i minus y_i
- Parameters
x
: The expression from which to subtract
y
: The expression to subtract
-
Expression
dynet::expr::
operator-
(real x, const Expression &y)¶ Scalar subtraction.
Subtract an expression from a scalar
- Return
- An expression where the ith element is x minus y_i
- Parameters
x
: The scalar from which to subtract
y
: The expression to subtract
-
Expression
dynet::expr::
operator-
(const Expression &x, real y)¶ Scalar subtraction.
Subtract a scalar from an expression
- Return
- An expression where the ith element is x_i minus y
- Parameters
x
: The expression from which to subtract
y
: The scalar to subtract
-
Expression
dynet::expr::
operator*
(const Expression &x, const Expression &y)¶ Matrix multiplication.
Multiply two matrices together. Like standard matrix multiplication, the second dimension of x and the first dimension of y must match.
- Return
- An expression x times y
- Parameters
x
: The left-hand matrix
y
: The right-hand matrix
-
Expression
dynet::expr::
operator*
(const Expression &x, float y)¶ Matrix-scalar multiplication.
Multiply an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i times y
- Parameters
x
: The matrix
y
: The scalar
-
Expression
dynet::expr::
operator*
(float y, const Expression &x)¶ Matrix-scalar multiplication.
Multiply an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i times y
- Parameters
y
: The scalar
x
: The matrix
-
Expression
dynet::expr::
operator/
(const Expression &x, float y)¶ Matrix-scalar division.
Divide an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i divided by y
- Parameters
x
: The matrix
y
: The scalar
-
Expression
dynet::expr::
affine_transform
(const std::initializer_list<Expression> &xs)¶ Affine transform.
This performs an affine transform over an arbitrary (odd) number of expressions held in the input initializer list xs. The first expression is the “bias,” which is added to the expression as-is. The remaining expressions are multiplied together in pairs, then added. A very common use case is the calculation of the score for a neural network layer (e.g. b + Wz) where b is the bias, W is the weight matrix, and z is the input. In this case xs[0] = b, xs[1] = W, and xs[2] = z.
- Return
- An expression equal to: xs[0] + xs[1]*xs[2] + xs[3]*xs[4] + ...
- Parameters
xs
: An initializer list containing an odd number of expressions
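With plain vectors, the common case xs = {b, W, z} reduces to b + W*z. The standalone sketch below assumes W is stored column-major with b.size() rows and z.size() columns; `affine` is a hypothetical name, and only the three-argument case is modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of affine_transform({b, W, z}) = b + W * z, with W a
// rows x cols matrix stored column-major and z a vector of length cols.
std::vector<float> affine(const std::vector<float>& b, const std::vector<float>& W,
                          const std::vector<float>& z) {
  const std::size_t rows = b.size();
  const std::size_t cols = z.size();
  assert(W.size() == rows * cols);
  std::vector<float> y(b);  // start from the bias
  for (std::size_t j = 0; j < cols; ++j)
    for (std::size_t i = 0; i < rows; ++i)
      y[i] += W[i + j * rows] * z[j];  // add column j of W scaled by z[j]
  return y;
}
```

For b = (1, 1), W = [[1, 2], [3, 4]] (stored column-major as {1, 3, 2, 4}) and z = (1, 1), the result is (4, 8).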
-
Expression
dynet::expr::
sum
(const std::initializer_list<Expression> &xs)¶ Sum.
This performs an elementwise sum over all the expressions in xs
- Return
- An expression where the ith element is equal to xs[0][i] + xs[1][i] + ...
- Parameters
xs
: An initializer list containing expressions
-
Expression
dynet::expr::
sum_elems
(const Expression &x)¶ Sum all elements.
Sum all the elements in an expression.
- Return
- The sum of all of its elements
- Parameters
x
: The input expression
-
Expression
dynet::expr::
average
(const std::initializer_list<Expression> &xs)¶ Average.
This performs an elementwise average over all the expressions in xs
- Return
- An expression where the ith element is equal to (xs[0][i] + xs[1][i] + ...)/|xs|
- Parameters
xs
: An initializer list containing expressions
-
Expression
dynet::expr::
sqrt
(const Expression &x)¶ Square root.
Elementwise square root.
- Return
- An expression where the ith element is equal to \(\sqrt{x_i}\)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
abs
(const Expression &x)¶ Absolute value.
Elementwise absolute value.
- Return
- An expression where the ith element is equal to \(\vert x_i\vert\)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
erf
(const Expression &x)¶ Gaussian error function.
Elementwise calculation of the Gaussian error function
- Return
- An expression where the ith element is equal to erf(x_i)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
tanh
(const Expression &x)¶ Hyperbolic tangent.
Elementwise calculation of the hyperbolic tangent
- Return
- An expression where the ith element is equal to tanh(x_i)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
exp
(const Expression &x)¶ Natural exponent.
Calculate elementwise y_i = e^{x_i}
- Return
- An expression where the ith element is equal to e^{x_i}
- Parameters
x
: The input expression
-
Expression
dynet::expr::
square
(const Expression &x)¶ Square.
Calculate elementwise y_i = x_i^2
- Return
- An expression where the ith element is equal to x_i^2
- Parameters
x
: The input expression
-
Expression
dynet::expr::
cube
(const Expression &x)¶ Cube.
Calculate elementwise y_i = x_i^3
- Return
- An expression where the ith element is equal to x_i^3
- Parameters
x
: The input expression
-
Expression
dynet::expr::
lgamma
(const Expression &x)¶ Log gamma.
Calculate elementwise y_i = ln(gamma(x_i))
- Return
- An expression where the ith element is equal to ln(gamma(x_i))
- Parameters
x
: The input expression
-
Expression
dynet::expr::
log
(const Expression &x)¶ Logarithm.
Calculate the elementwise natural logarithm y_i = ln(x_i)
- Return
- An expression where the ith element is equal to ln(x_i)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
logistic
(const Expression &x)¶ Logistic sigmoid function.
Calculate elementwise y_i = 1/(1+e^{-x_i})
- Return
- An expression where the ith element is equal to y_i = 1/(1+e^{-x_i})
- Parameters
x
: The input expression
-
Expression
dynet::expr::
rectify
(const Expression &x)¶ Rectifier.
Calculate elementwise the rectifier (ReLU) function y_i = max(x_i,0)
- Return
- An expression where the ith element is equal to max(x_i,0)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
softsign
(const Expression &x)¶ Soft Sign.
Calculate elementwise the softsign function y_i = x_i/(1+|x_i|)
- Return
- An expression where the ith element is equal to x_i/(1+|x_i|)
- Parameters
x
: The input expression
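The scalar formulas behind logistic, rectify and softsign can be checked directly; the expression versions apply them to every element. These helper names are hypothetical and independent of DyNet.

```cpp
#include <cassert>
#include <cmath>

// Scalar versions of three elementwise operations documented above.
float logistic_s(float x) { return 1.0f / (1.0f + std::exp(-x)); }  // 1/(1+e^{-x})
float rectify_s(float x) { return x > 0.0f ? x : 0.0f; }             // max(x, 0)
float softsign_s(float x) { return x / (1.0f + std::fabs(x)); }      // x/(1+|x|)
```

For instance, logistic_s(0) is 0.5, rectify_s(-2) is 0, and softsign_s(-3) is -0.75.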
-
Expression
dynet::expr::
pow
(const Expression &x, const Expression &y)¶ Power function.
Calculate an output where the ith element is equal to x_i^{y_i}
- Return
- An expression where the ith element is equal to x_i^{y_i}
- Parameters
x
: The input expression
y
: The exponent expression
-
Expression
dynet::expr::
min
(const Expression &x, const Expression &y)¶ Minimum.
Calculate an output where the ith element is min(x_i,y_i)
- Return
- An expression where the ith element is equal to min(x_i,y_i)
- Parameters
x
: The first input expression
y
: The second input expression
-
Expression
dynet::expr::
max
(const Expression &x, const Expression &y)¶ Maximum.
Calculate an output where the ith element is max(x_i,y_i)
- Return
- An expression where the ith element is equal to max(x_i,y_i)
- Parameters
x
: The first input expression
y
: The second input expression
-
Expression
dynet::expr::
max
(const std::initializer_list<Expression> &xs)¶ Max.
This performs an elementwise max over all the expressions in xs
- Return
- An expression where the ith element is equal to max(xs[0][i], xs[1][i], ...)
- Parameters
xs
: An initializer list containing expressions
-
Expression
dynet::expr::
dot_product
(const Expression &x, const Expression &y)¶ Dot Product.
Calculate the dot product sum_i x_i*y_i
- Return
- An expression equal to the dot product
- Parameters
x
: The first input expression
y
: The second input expression
-
Expression
dynet::expr::
cmult
(const Expression &x, const Expression &y)¶ Componentwise multiply.
Do a componentwise multiply where each value is equal to x_i*y_i. This function used to be called cwise_multiply.
- Return
- An expression where the ith element is equal to x_i*y_i
- Parameters
x
: The first input expression
y
: The second input expression
-
Expression
dynet::expr::
cdiv
(const Expression &x, const Expression &y)¶ Componentwise division.
Do a componentwise division where each value is equal to x_i/y_i
- Return
- An expression where the ith element is equal to x_i/y_i
- Parameters
x
: The first input expression
y
: The second input expression
-
Expression
dynet::expr::
colwise_add
(const Expression &x, const Expression &bias)¶ Columnwise addition.
Add vector “bias” to each column of matrix “x”
- Return
- An expression where bias is added to each column of x
- Parameters
x
: An MxN matrix
bias
: A length M vector
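Because DyNet stores matrices in column-major order, adding a bias to each column amounts to walking the flat buffer one column at a time. The following is a minimal sketch over a raw `std::vector<float>`; `colwise_add_sketch` is a hypothetical helper, not part of DyNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: adds a length-M bias vector to every column of an
// M x N matrix stored in column-major order.
std::vector<float> colwise_add_sketch(const std::vector<float>& x, std::size_t rows,
                                      std::size_t cols, const std::vector<float>& bias) {
  assert(x.size() == rows * cols && bias.size() == rows);
  std::vector<float> y(x);
  for (std::size_t c = 0; c < cols; ++c)
    for (std::size_t r = 0; r < rows; ++r)
      y[c * rows + r] += bias[r];  // element (r, c) lives at flat index c * rows + r
  return y;
}
```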
Probability/Loss Operations¶
These operations are used for calculating probabilities, or calculating loss functions for use in training.
-
Expression
dynet::expr::
softmax
(const Expression &x)¶ Softmax.
The softmax function normalizes each column so that all values are between 0 and 1 and sum to one, computing e^{x_i}/{sum_j e^{x_j}} for each element x_i of the column.
- Return
- A vector or matrix after calculating the softmax
- Parameters
x
: A vector or matrix
-
Expression
dynet::expr::
log_softmax
(const Expression &x)¶ Log softmax.
The log softmax function computes the logarithm of the softmax of each column, ln(e^{x_i}/{sum_j e^{x_j}}) = x_i - ln(sum_j e^{x_j}). The results are log-probabilities: every value is at most 0, and their exponentials in each column sum to one.
- Return
- A vector or matrix after calculating the log softmax
- Parameters
x
: A vector or matrix
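A minimal standalone sketch of softmax and log softmax on a single column, assuming plain `std::vector<double>` inputs rather than DyNet expressions. It uses the usual max-subtraction trick so that large scores do not overflow; this is a common technique, and we are not claiming it is exactly how DyNet computes these functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// log_softmax(x)_i = x_i - log(sum_j exp(x_j)), computed stably by
// subtracting the maximum score before exponentiating.
std::vector<double> log_softmax_sketch(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - m);
  const double log_z = m + std::log(sum);  // log of the normalizer
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] - log_z;
  return y;
}

// softmax is the elementwise exponential of the log softmax.
std::vector<double> softmax_sketch(const std::vector<double>& x) {
  std::vector<double> y = log_softmax_sketch(x);
  for (double& v : y) v = std::exp(v);
  return y;
}
```

With this formulation `softmax_sketch({1000.0, 1000.0})` returns `{0.5, 0.5}` instead of overflowing.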
-
Expression
dynet::expr::
log_softmax
(const Expression &x, const std::vector<unsigned> &restriction)¶ Restricted log softmax.
The log softmax function calculated over only a subset of the vector elements. The elements to be included are set by the restriction variable. All elements not included in restriction are set to negative infinity.
- Return
- A vector with the log softmax over the specified elements
- Parameters
x
: A vector over which to calculate the softmax
restriction
: The elements over which to calculate the softmax
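The restricted variant can be sketched the same way: normalize over the listed indices only and fill everything else with negative infinity. `restricted_log_softmax_sketch` is a hypothetical name, and the sketch assumes the indices in `restriction` are distinct.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Log softmax over the indices in `restriction` only (assumed distinct);
// entries outside the restriction are set to negative infinity.
std::vector<double> restricted_log_softmax_sketch(const std::vector<double>& x,
                                                  const std::vector<unsigned>& restriction) {
  assert(!restriction.empty());
  const double neg_inf = -std::numeric_limits<double>::infinity();
  double m = neg_inf;
  for (unsigned i : restriction) {
    assert(i < x.size());
    m = std::max(m, x[i]);
  }
  double sum = 0.0;
  for (unsigned i : restriction) sum += std::exp(x[i] - m);
  const double log_z = m + std::log(sum);  // normalizer over the subset only
  std::vector<double> y(x.size(), neg_inf);
  for (unsigned i : restriction) y[i] = x[i] - log_z;
  return y;
}
```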
-
Expression
dynet::expr::
logsumexp
(const std::initializer_list<Expression> &xs)¶ Log, sum, exp.
The elementwise “logsumexp” function that calculates \(ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.
- Return
- An expression containing the elementwise logsumexp of xs
- Parameters
xs
: Expressions with respect to which to calculate the logsumexp.
-
Expression
dynet::expr::
pickneglogsoftmax
(const Expression &x, unsigned v)¶ Negative softmax log likelihood.
This function takes in a vector of scores x, performs a log softmax, negates the result, and selects the entry corresponding to element v. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.
- Return
- The negative log likelihood of element v after taking the softmax
- Parameters
x
: A vector of scores
v
: The element with which to calculate the loss
-
Expression
dynet::expr::
pickneglogsoftmax
(const Expression &x, const unsigned *pv)¶ Modifiable negative softmax log likelihood.
This function calculates the negative log likelihood after the softmax with respect to index *pv. This computes the same value as the previous function that passes the index v by value, but instead passes it by pointer so that the value *pv can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.
- Return
- The negative log likelihood of element *pv after taking the softmax
- Parameters
x
: A vector of scores
pv
: A pointer to the index of the correct element
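The point of passing *pv by pointer is that the index is read when the graph is evaluated, not when it is built. The toy C++ struct below is hypothetical (it is not DyNet's node class) and only illustrates that idea; it recomputes \(-\log \mathrm{softmax}(x)_v\) from the current value of *pv each time `forward()` is called.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for a graph node that is built once and re-evaluated.
// It stores pointers, so changing *pv between evaluations changes the loss
// without rebuilding the node.
struct PickNegLogSoftmaxNode {
  const std::vector<double>* scores;  // borrowed; must outlive the node
  const unsigned* pv;                 // borrowed; read at evaluation time

  double forward() const {
    const std::vector<double>& x = *scores;
    assert(!x.empty() && *pv < x.size());
    const double m = *std::max_element(x.begin(), x.end());
    double sum = 0.0;
    for (double v : x) sum += std::exp(v - m);
    // -log softmax(x)[v] = log(sum_j exp(x_j)) - x_v
    return m + std::log(sum) - x[*pv];
  }
};
```

After building `PickNegLogSoftmaxNode node{&scores, &target};` once, assigning a new value to `target` and calling `node.forward()` again yields the loss for the new target without constructing anything new.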
-
Expression
dynet::expr::
pickneglogsoftmax
(const Expression &x, const std::vector<unsigned> &v)¶ Batched negative softmax log likelihood.
This function is similar to standard pickneglogsoftmax, but calculates loss with respect to multiple batch elements. The input will be a mini-batch of score vectors where the number of batch elements is equal to the number of indices in v.
- Return
- The negative log likelihoods over all the batch elements
- Parameters
x
: An expression with vectors of scores over N batch elements
v
: A size-N vector giving the index of the correct element for each batch element
-
Expression
dynet::expr::
pickneglogsoftmax
(const Expression &x, const std::vector<unsigned> *pv)¶ Modifiable batched negative softmax log likelihood.
This function is a combination of modifiable pickneglogsoftmax and batched pickneglogsoftmax: *pv can be modified without re-creating the computation graph.
- Return
- The negative log likelihoods over all the batch elements
- Parameters
x
: An expression with vectors of scores over N batch elements
pv
: A pointer to the indexes
-
Expression
dynet::expr::
hinge
(const Expression &x, unsigned index, float m = 1.0)¶ Hinge loss.
This expression calculates the hinge loss, formally expressed as: \( \text{hinge}(x,index,m) = \sum_{i \ne index} \max(0, m-x[index]+x[i]). \)
- Return
- The hinge loss of candidate index with respect to margin m
- Parameters
x
: A vector of scores
index
: The index of the correct candidate
m
: The margin
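A direct transcription of the hinge formula above, as a standalone sketch over `std::vector<double>` (the function name is illustrative, not DyNet's API):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// hinge(x, index, m) = sum over i != index of max(0, m - x[index] + x[i]).
double hinge_sketch(const std::vector<double>& x, unsigned index, double m = 1.0) {
  assert(index < x.size());
  double loss = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i)
    if (i != index) loss += std::max(0.0, m - x[index] + x[i]);  // margin violation of i
  return loss;
}
```

With scores `{3.0, 1.0, 2.5}` and correct index 0, only the third candidate is within the default margin of 1, so the loss is 0.5.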
-
Expression
dynet::expr::
hinge
(const Expression &x, const unsigned *pindex, float m = 1.0)¶ Modifiable hinge loss.
This function calculates the hinge loss with respect to index *pindex. This computes the same value as the previous function that passes the index index by value, but instead passes it by pointer so that the value *pindex can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.
- Return
- The hinge loss of candidate *pindex with respect to margin m
- Parameters
x
: A vector of scores
pindex
: A pointer to the index of the correct candidate
m
: The margin
-
Expression
dynet::expr::
hinge
(const Expression &x, const std::vector<unsigned> &indices, float m = 1.0)¶ Batched hinge loss.
The same as hinge loss, but for the case where x is a mini-batched tensor with indices.size() batch elements, and indices is a vector giving the index of the correct element for each batch element.
- Return
- The hinge loss of each mini-batch
- Parameters
x
: A mini-batch of vectors with indices.size() batch elements
indices
: The indices of the correct candidates for each batch element
m
: The margin
-
Expression
dynet::expr::
hinge
(const Expression &x, const std::vector<unsigned> *pindices, float m = 1.0)¶ Batched modifiable hinge loss.
A combination of the previous batched and modifiable hinge loss functions, where the vector *pindices can be modified without re-creating the computation graph.
- Return
- The hinge loss of each mini-batch
- Parameters
x
: A mini-batch of vectors with pindices->size() batch elements
pindices
: Pointer to the indices of the correct candidates for each batch element
m
: The margin
-
Expression
dynet::expr::
sparsemax
(const Expression &x)¶ Sparsemax.
The sparsemax function (Martins et al. 2016), which is similar to softmax, but induces sparse solutions where many of the vector elements are exactly zero. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax of the scores
- Parameters
x
: A vector of scores
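For intuition, sparsemax is the Euclidean projection of the scores onto the probability simplex, and the sparsemax paper cited above gives a sort-based procedure for computing it. The sketch below is our own transcription of that procedure for a single vector, not DyNet's code; `sparsemax_sketch` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Sort-based sparsemax: find the support size k, compute the threshold tau,
// then clip z_i - tau at zero.
std::vector<double> sparsemax_sketch(const std::vector<double>& z) {
  assert(!z.empty());
  std::vector<double> sorted(z);
  std::sort(sorted.begin(), sorted.end(), std::greater<double>());
  // k is the largest index with 1 + k * z_(k) > sum_{j<=k} z_(j) (1-based).
  double cumsum = 0.0, support_sum = 0.0;
  std::size_t k = 0;
  for (std::size_t j = 0; j < sorted.size(); ++j) {
    cumsum += sorted[j];
    if (1.0 + (j + 1) * sorted[j] > cumsum) {
      k = j + 1;
      support_sum = cumsum;
    }
  }
  const double tau = (support_sum - 1.0) / k;  // k >= 1 always holds
  std::vector<double> p(z.size());
  for (std::size_t i = 0; i < z.size(); ++i) p[i] = std::max(z[i] - tau, 0.0);
  return p;
}
```

Unlike softmax, `sparsemax_sketch({2.0, 1.0, 0.1})` puts all of the mass on the first element and returns exact zeros elsewhere.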
-
Expression
dynet::expr::
sparsemax_loss
(const Expression &x, const std::vector<unsigned> &target_support)¶ Sparsemax loss.
The sparsemax loss function (Martins et al. 2016), which is similar to softmax loss, but induces sparse solutions where many of the vector elements are exactly zero. It has a gradient similar to the sparsemax function and is thus useful for optimizing when the sparsemax will be used at test time. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax loss of the labels
- Parameters
x
: A vector of scores
target_support
: The target correct labels.
-
Expression
dynet::expr::
sparsemax_loss
(const Expression &x, const std::vector<unsigned> *ptarget_support)¶ Modifiable sparsemax loss.
Similar to the sparsemax loss, but with ptarget_support being a pointer to a vector, allowing it to be modified without re-creating the computation graph. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax loss of the labels
- Parameters
x
: A vector of scores
ptarget_support
: A pointer to the target correct labels.
-
Expression
dynet::expr::
squared_norm
(const Expression &x)¶ Squared norm.
The squared norm of the values of x: \(\sum_i x_i^2\).
- Return
- The squared norm
- Parameters
x
: A vector of values
-
Expression
dynet::expr::
squared_distance
(const Expression &x, const Expression &y)¶ Squared distance.
The squared distance between values of x and y: \(\sum_i (x_i-y_i)^2\).
- Return
- The squared distance
- Parameters
x
: A vector of values
y
: Another vector of values
-
Expression
dynet::expr::
l1_distance
(const Expression &x, const Expression &y)¶ L1 distance.
The L1 distance between values of x and y: \(\sum_i |x_i-y_i|\).
- Return
- The L1 distance
- Parameters
x
: A vector of values
y
: Another vector of values
-
Expression
dynet::expr::
huber_distance
(const Expression &x, const Expression &y, float c = 1.345f)¶ Huber distance.
The Huber distance between values of x and y, parameterized by c: \(\sum_i L_c(x_i, y_i)\), where
\( L_c(x, y) = \begin{cases} \frac{1}{2}(y - x)^2 & \textrm{for } |y - x| \le c, \\ c\, |y - x| - \frac{1}{2}c^2 & \textrm{otherwise.} \end{cases} \)
- Return
- The Huber distance
- Parameters
x
: A vector of values
y
: Another vector of values
c
: The parameter of the Huber distance that sets the cutoff between the quadratic and linear regimes
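The piecewise definition of the Huber distance translates directly into code. A minimal standalone sketch, with `huber_distance_sketch` as an illustrative name and the default `c` copied from the signature above (as a double rather than a float):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sum of L_c(x_i, y_i): quadratic for residuals up to c, linear beyond c.
double huber_distance_sketch(const std::vector<double>& x, const std::vector<double>& y,
                             double c = 1.345) {
  assert(x.size() == y.size());
  double total = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double a = std::fabs(y[i] - x[i]);  // absolute residual
    total += (a <= c) ? 0.5 * a * a : c * a - 0.5 * c * c;
  }
  return total;
}
```

The two branches agree at \(a = c\) (both give \(c^2/2\)), so the loss is continuous; beyond the cutoff, large residuals contribute linearly rather than quadratically.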
-
Expression
dynet::expr::
binary_log_loss
(const Expression &x, const Expression &y)¶ Binary log loss.
The log loss of a binary decision according to the sigmoid function: \(- \sum_i (y_i * ln(x_i) + (1-y_i) * ln(1-x_i)) \)
- Return
- The log loss of the sigmoid function
- Parameters
x
: A vector of predicted probabilities (for example, sigmoid outputs), each in (0, 1)
y
: A vector of true answers
-
Expression
dynet::expr::
pairwise_rank_loss
(const Expression &x, const Expression &y, real m = 1.0)¶ Pairwise rank loss.
A margin-based loss, where every margin violation for each pair of values is penalized: \(\sum_i max(x_i-y_i+m, 0)\)
- Return
- The pairwise rank loss
- Parameters
x
: A vector of values
y
: A vector of true answers
m
: The margin
-
Expression
dynet::expr::
poisson_loss
(const Expression &x, unsigned y)¶ Poisson loss.
The negative log probability of y according to a Poisson distribution with parameter x. Useful in Poisson regression, where we try to predict the parameters of a Poisson distribution to maximize the probability of data y.
- Return
- The Poisson loss
- Parameters
x
: The parameter of the Poisson distribution.
y
: The target value
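As a hedged illustration, the textbook Poisson negative log-likelihood of a count \(y\) under rate \(\lambda\) is \(\lambda - y\ln\lambda + \ln(y!)\). The sketch below assumes `x` is the rate \(\lambda\) itself, as the parameter description suggests. Before relying on exact values, check in the DyNet source whether `x` is the rate or some transform of it (for example its logarithm).

```cpp
#include <cassert>
#include <cmath>

// Textbook Poisson negative log-likelihood, ASSUMING lambda is the rate itself:
// -log P(y | lambda) = lambda - y * log(lambda) + log(y!).
double poisson_nll_sketch(double lambda, unsigned y) {
  assert(lambda > 0.0);
  // std::lgamma(y + 1) equals log(y!)
  return lambda - y * std::log(lambda) + std::lgamma(y + 1.0);
}
```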
-
Expression
dynet::expr::
poisson_loss
(const Expression &x, const unsigned *py)¶ Modifiable Poisson loss.
Similar to Poisson loss, but with the target value passed by pointer so that it can be modified without re-constructing the computation graph.
- Return
- The Poisson loss
- Parameters
x
: The parameter of the Poisson distribution.
py
: A pointer to the target value
Flow/Shaping Operations¶
These operations control the flow of information through the graph, or the shape of the vectors/tensors used in the graph.
-
Expression
dynet::expr::
nobackprop
(const Expression &x)¶ Prevent backprop.
This node has no effect on the forward pass, but prevents gradients from flowing backward during the backward pass. This is useful when there’s a subgraph for which you don’t want loss passed back to the parameters.
- Return
- The new expression
- Parameters
x
: The input expression
-
Expression
dynet::expr::
flip_gradient
(const Expression &x)¶ Negative backprop.
This node has no effect on the forward pass, but negates the gradient during the backward pass. This operation is widely used in adversarial training.
- Return
- An output expression with the same value as the input (the operation only affects the backward pass)
- Parameters
x
: The input expression
-
Expression
dynet::expr::
reshape
(const Expression &x, const Dim &d)¶ Reshape to another size.
This node reshapes a tensor to another size, without changing the underlying layout of the data. The layout of the data in DyNet is column-major, so if we have a 3x4 matrix
\( \begin{pmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \\ \end{pmatrix} \)
and transform it into a 2x6 matrix, it will be rearranged as:
\( \begin{pmatrix} x_{1,1} & x_{3,1} & x_{2,2} & x_{1,3} & x_{3,3} & x_{2,4} \\ x_{2,1} & x_{1,2} & x_{3,2} & x_{2,3} & x_{1,4} & x_{3,4} \\ \end{pmatrix} \)
**Note:** This is O(1) for forward, and O(n) for backward.
- Return
- The reshaped expression
- Parameters
x
: The input expression
d
: The new dimensions
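The 3x4 to 2x6 example can be reproduced with a small column-major view that reinterprets the same flat buffer under new dimensions. `ColMajorView` is a hypothetical type for illustration only: element \((r, c)\) of an \(R \times C\) matrix sits at flat index \(cR + r\) (0-based indices, while the labels such as x_{3,1} above are 1-based).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// A reshape only changes the (row, col) -> flat-index mapping; the shared
// buffer itself is never copied or reordered.
struct ColMajorView {
  const std::vector<std::string>* data;  // shared buffer
  std::size_t rows, cols;

  const std::string& at(std::size_t r, std::size_t c) const {
    assert(r < rows && c < cols && rows * cols == data->size());
    return (*data)[c * rows + r];
  }
};
```

Filling a buffer column by column with the labels x11, x21, x31, x12, and so on, then viewing it as 2x6, places x31 at row 0, column 1 and x12 at row 1, column 1, matching the matrix shown above.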
-
Expression
dynet::expr::
transpose
(const Expression &x, const std::vector<unsigned> &dims = {1, 0})¶ Transpose a matrix.
Transpose a matrix or tensor, or if dims is specified shuffle the dimensions arbitrarily. Note: This is O(1) if either the row or column dimension is 1, and O(n) otherwise.
- Return
- The transposed/shuffled expression
- Parameters
x
: The input expression
dims
: The dimensions to swap. The ith dimension of the output will be equal to the dims[i] dimension of the input. dims must have the same number of dimensions as x.
-
Expression
dynet::expr::
select_rows
(const Expression &x, const std::vector<unsigned> &rows)¶ Select rows.
Select a subset of rows of a matrix.
- Return
- An expression containing the selected rows
- Parameters
x
: The input expression
rows
: The rows to extract
-
Expression
dynet::expr::
select_rows
(const Expression &x, const std::vector<unsigned> *prows)¶ Modifiable select rows.
Select a subset of rows of a matrix, where the elements of prows can be modified without re-creating the computation graph.
- Return
- An expression containing the selected rows
- Parameters
x
: The input expression
prows
: The rows to extract
-
Expression
dynet::expr::
select_cols
(const Expression &x, const std::vector<unsigned> &cols)¶ Select columns.
Select a subset of columns of a matrix. select_cols is more efficient than select_rows since DyNet uses column-major order.
- Return
- An expression containing the selected columns
- Parameters
x
: The input expression
cols
: The columns to extract
-
Expression
dynet::expr::
select_cols
(const Expression &x, const std::vector<unsigned> *pcols)¶ Modifiable select columns.
Select a subset of columns of a matrix, where the elements of pcols can be modified without re-creating the computation graph.
- Return
- An expression containing the selected columns
- Parameters
x
: The input expression
pcols
: The columns to extract
-
Expression
dynet::expr::
sum_batches
(const Expression &x)¶ Sum over minibatches.
Sum an expression that consists of multiple minibatches into one of equal dimension but with only a single minibatch. This is useful for summing loss functions at the end of minibatch training.
- Return
- An expression with a single batch
- Parameters
x
: The input mini-batched expression
-
Expression
dynet::expr::
pick
(const Expression &x, unsigned v, unsigned d = 0)¶ Pick element.
Pick a single element/row/column/sub-tensor from an expression. This will result in the dimension of the tensor being reduced by 1.
- Return
- The value of x[v] along dimension d
- Parameters
x
: The input expression
v
: The index of the element to select
d
: The dimension along which to choose the element
-
Expression
dynet::expr::
pick
(const Expression &x, const std::vector<unsigned> &v, unsigned d = 0)¶ Batched pick.
Pick one element from each batch element of the input, using the corresponding index in v.
- Return
- A mini-batched expression containing the picked elements
- Parameters
x
: The input expression
v
: A vector of indices to choose, one for each batch in the input expression.
d
: The dimension along which to choose the elements
-
Expression
dynet::expr::
pick
(const Expression &x, const unsigned *pv, unsigned d = 0)¶ Modifiable pick element.
Pick a single element from an expression, where the index is passed by pointer so we do not need to re-create the computation graph every time.
- Return
- The value of x[*pv]
- Parameters
x
: The input expression
pv
: Pointer to the index of the element to select
d
: The dimension along which to choose the elements
-
Expression
dynet::expr::
pick
(const Expression &x, const std::vector<unsigned> *pv, unsigned d = 0)¶ Modifiable batched pick element.
Pick multiple elements from an input expression, where the indices are passed by pointer so we do not need to re-create the computation graph every time.
- Return
- A mini-batched expression containing the picked elements
- Parameters
x
: The input expression
pv
: A pointer to a vector of indices to choose
d
: The dimension along which to choose the elements
-
Expression
dynet::expr::
pickrange
(const Expression &x, unsigned v, unsigned u)¶ Pick range of elements.
Pick a range of elements from an expression.
- Return
- The value of {x[v],...,x[u]}
- Parameters
x
: The input expression
v
: The beginning index
u
: The end index
-
Expression
dynet::expr::
pick_batch_elem
(const Expression &x, unsigned v)¶ (Modifiable) Pick batch element.
Pick batch element from a batched expression. For a Tensor with 3 batch elements:
\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
pick_batch_elem(t, 1) will return a Tensor of
\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \)
- Return
- The expression of the picked batch element. The picked element is a tensor whose bd equals one.
- Parameters
x
: The input expression
v
: The index of the batch element to be picked.
-
Expression
dynet::expr::
pick_batch_elems
(const Expression &x, const std::vector<unsigned> &v)¶ (Modifiable) Pick batch elements.
Pick several batch elements from a batched expression. For a Tensor with 3 batch elements:
\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
pick_batch_elems(t, {1, 2}) will return a Tensor with 2 batch elements:
\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
- Return
- The expression of the picked batch elements. The result is a tensor whose bd equals the size of vector v.
- Parameters
x
: The input expression
v
: A vector of indices of the batch elements to be picked.
-
Expression
dynet::expr::
pick_batch_elem
(const Expression &x, const unsigned *v)¶ Pick batch element.
Pick batch element from a batched expression.
- Return
- The expression of the picked batch element. The picked element is a tensor whose bd equals one.
- Parameters
x
: The input expression
v
: A pointer to the index of the correct element to be picked.
-
Expression
dynet::expr::
pick_batch_elems
(const Expression &x, const std::vector<unsigned> *pv)¶ Pick batch elements.
Pick several batch elements from a batched expression.
- Return
- The expression of the picked batch elements. The result is a tensor whose bd equals the size of vector *pv.
- Parameters
x
: The input expression
pv
: A pointer to the indexes
-
Expression
dynet::expr::
concatenate_to_batch
(const std::initializer_list<Expression> &xs)¶ Concatenate list of expressions to a single batched expression.
Perform a concatenation of several expressions along the batch dimension. All expressions must have the same shape except for the batch dimension.
- Return
- The expression with the batch dimensions concatenated
- Parameters
xs
: The input expressions
-
Expression
dynet::expr::
concatenate_cols
(const std::initializer_list<Expression> &xs)¶ Concatenate columns.
Perform a concatenation of the columns in multiple expressions. All expressions must have the same number of rows.
- Return
- The expression with the columns concatenated
- Parameters
xs
: The input expressions
-
Expression
dynet::expr::
concatenate
(const std::initializer_list<Expression> &xs, unsigned d = 0)¶ Concatenate.
Perform a concatenation of multiple expressions along a particular dimension. All expressions must have the same dimensions except for the dimension to be concatenated (rows by default).
- Return
- The expression with the specified dimension concatenated
- Parameters
xs
: The input expressions
d
: The dimension along which to perform concatenation
-
Expression
dynet::expr::
max_dim
(const Expression &x, unsigned d = 0)¶ Max out through a dimension.
Select out an element/row/column/sub-tensor from an expression, with maximum value along a given dimension. This will result in the dimension of the tensor being reduced by 1.
- Return
- An expression of sub-tensor with max value along dimension d
- Parameters
x
: The input expression
d
: The dimension along which to choose the element
-
Expression
dynet::expr::
min_dim
(const Expression &x, unsigned d = 0)¶ Min out through a dimension.
Select out an element/row/column/sub-tensor from an expression, with minimum value along a given dimension. This will result in the dimension of the tensor being reduced by 1.
- Return
- An expression of sub-tensor with min value along dimension d
- Parameters
x
: The input expression
d
: The dimension along which to choose the element
Noise Operations¶
These operations are used to add noise to the graph for purposes of making learning more robust.
-
Expression
dynet::expr::
noise
(const Expression &x, real stddev)¶ Gaussian noise.
Add Gaussian noise to an expression.
- Return
- The noised expression
- Parameters
x
: The input expression
stddev
: The standard deviation of the Gaussian
-
Expression
dynet::expr::
dropout
(const Expression &x, real p)¶ Dropout.
With a fixed probability p, drop out (set to zero) nodes in the input expression, and scale the remaining nodes by 1/(1-p). Note that there are two kinds of dropout:
- Regular dropout: where we perform dropout at training time and then scale outputs by (1-p) at test time.
- Inverted dropout: where we perform dropout and scaling at training time, and do not need to do anything at test time. DyNet implements the latter, so you only need to apply dropout at training time, and do not need to perform scaling at test time.
- Return
- The dropped out expression
- Parameters
x
: The input expression
p
: The dropout probability
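Inverted dropout, as described above, can be sketched in a few lines, assuming `p` is the drop probability: survivors are scaled by \(1/(1-p)\) so that each output keeps its input's expected value. This standalone sketch uses `std::bernoulli_distribution` and is not DyNet's implementation; the function name is illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Inverted dropout: zero each value with probability p and scale survivors
// by 1/(1-p), so nothing needs to be rescaled at test time.
std::vector<float> inverted_dropout_sketch(const std::vector<float>& x, float p,
                                           std::mt19937& rng) {
  assert(p >= 0.0f && p < 1.0f);
  std::bernoulli_distribution keep(1.0 - p);  // true with probability 1 - p
  const float scale = 1.0f / (1.0f - p);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = keep(rng) ? x[i] * scale : 0.0f;
  return y;
}
```

With `p = 0.5`, every surviving value is doubled, and the mean over many inputs of 1.0 stays close to 1.0.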
-
Expression
dynet::expr::
block_dropout
(const Expression &x, real p)¶ Block dropout.
Identical to the dropout operation, but either drops out all or no values in the expression, as opposed to making a decision about each value individually.
- Return
- The block dropout expression
- Parameters
x
: The input expression
p
: The block dropout probability
Tensor Operations¶
These operations are used for performing operations on higher order tensors.
-
Expression
dynet::expr::
contract3d_1d
(const Expression &x, const Expression &y)¶ Contracts a rank 3 tensor and a rank 1 tensor into a rank 2 tensor.
The resulting tensor \(z\) has coordinates \(z_{ij} = \sum_k x_{ijk} y_k\)
- Return
- Matrix
- Parameters
x
: Rank 3 tensor
y
: Vector
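A naive transcription of \(z_{ij} = \sum_k x_{ijk} y_k\), using nested vectors so that the indexing is explicit. This only illustrates the arithmetic, not DyNet's memory layout, and the type and function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Tensor3 = std::vector<std::vector<std::vector<double>>>;  // x[i][j][k]
using Matrix = std::vector<std::vector<double>>;                // z[i][j]

// z_ij = sum_k x_ijk * y_k: contract the last index of x with the vector y.
Matrix contract3d_1d_sketch(const Tensor3& x, const std::vector<double>& y) {
  Matrix z(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    z[i].assign(x[i].size(), 0.0);
    for (std::size_t j = 0; j < x[i].size(); ++j) {
      assert(x[i][j].size() == y.size());
      for (std::size_t k = 0; k < y.size(); ++k) z[i][j] += x[i][j][k] * y[k];
    }
  }
  return z;
}
```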
-
Expression
dynet::expr::
contract3d_1d_1d
(const Expression &x, const Expression &y, const Expression &z)¶ Contracts a rank 3 tensor and two rank 1 tensors into a rank 1 tensor.
This is the equivalent of calling contract3d_1d and then performing a matrix vector multiplication.
The resulting tensor \(t\) has coordinates \(t_i = \sum_{j,k} x_{ijk} y_k z_j\)
- Return
- Vector
- Parameters
x
: Rank 3 tensor
y
: Vector
z
: Vector
-
Expression
dynet::expr::
contract3d_1d_1d
(const Expression &x, const Expression &y, const Expression &z, const Expression &b)¶ Same as contract3d_1d_1d with an additional bias parameter.
This is the equivalent of calling contract3d_1d and then performing an affine transform.
The resulting tensor \(t\) has coordinates \(t_i = b_i + \sum_{j,k} x_{ijk} y_k z_j\)
- Return
- Vector
- Parameters
x
: Rank 3 tensor
y
: Vector
z
: Vector
b
: Bias vector
-
Expression
dynet::expr::
contract3d_1d
(const Expression &x, const Expression &y, const Expression &b)¶ Same as contract3d_1d with an additional bias parameter.
The resulting tensor \(z\) has coordinates \(z_{ij} = b_{ij}+\sum_k x_{ijk} y_k\)
- Return
- Matrix
- Parameters
x
: Rank 3 tensor
y
: Vector
b
: Bias matrix
Linear Algebra Operations¶
These operations are used for performing various operations common in linear algebra.
-
Expression
dynet::expr::
inverse
(const Expression &x)¶ Matrix Inverse.
Takes the inverse of a matrix (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158). Note that back-propagating through an inverted matrix can sometimes cause numerical stability problems.
- Return
- The inverse of the matrix
- Parameters
x
: A square matrix
-
Expression
dynet::expr::
logdet
(const Expression &x)¶ Log determinant.
Takes the log of the determinant of a matrix (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).
- Return
- The log of its determinant
- Parameters
x
: A square matrix
-
Expression
dynet::expr::
trace_of_product
(const Expression &x, const Expression &y)¶ Trace of Matrix Product.
Takes the trace of the product of matrices (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).
- Return
- trace(x * y)
- Parameters
x
: A matrix
y
: Another matrix
Convolution Operations¶
These operations are convolution-related.
-
Expression
dynet::expr::
conv2d
(const Expression &x, const Expression &f, const std::vector<unsigned> &stride, bool is_valid = true)¶ conv2d without bias
2D convolution operator without bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. When the stride is 1, the distinction is:
- SAME: the output size is the same as the input size. To do so, one needs to pad the input so the filter can sweep outside of the input maps.
- VALID: output size shrinks by filter_size - 1, and the filters always sweep at valid positions inside the input maps. No padding needed.
In detail, assume:
- Input feature maps: (XH x XW x XC) x N
- Filters: FH x FW x XC x FC, 4D tensor
- Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.
For the SAME convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH) / float(strides[0]))
- YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as:
- pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
- pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
- pad_top = pad_along_height / 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width / 2
- pad_right = pad_along_width - pad_left
For the VALID convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH - FH + 1) / float(strides[0]))
- YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zeros.
- Return
- The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
- Parameters
x
: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
f
: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
stride
: The row and column strides
is_valid
: ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’)
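The output-size and padding rules above can be applied to one spatial axis at a time (once for height, once for width). In the sketch below, `ConvAxis` and `conv_axis_sketch` are hypothetical names, and `ceil(float(a) / float(b))` is replaced by integer ceiling division `(a + b - 1) / b`, which gives the same result for positive integers.

```cpp
#include <algorithm>
#include <cassert>

struct ConvAxis {
  unsigned out;         // YH or YW
  unsigned pad_before;  // pad_top or pad_left
  unsigned pad_after;   // pad_bottom or pad_right
};

// Output size and padding along one axis, following the SAME/VALID rules above.
ConvAxis conv_axis_sketch(unsigned in, unsigned filter, unsigned stride, bool is_valid) {
  assert(stride > 0 && filter > 0);
  if (is_valid) {
    assert(in >= filter);
    // ceil((in - filter + 1) / stride), no padding
    return {(in - filter + 1 + stride - 1) / stride, 0, 0};
  }
  const unsigned out = (in + stride - 1) / stride;  // ceil(in / stride)
  const int needed = static_cast<int>((out - 1) * stride + filter) - static_cast<int>(in);
  const unsigned pad = static_cast<unsigned>(std::max(needed, 0));
  return {out, pad / 2, pad - pad / 2};  // any odd leftover goes after
}
```

For a 5-wide input, a 3-wide filter and stride 2, VALID gives 2 outputs with no padding, while SAME gives 3 outputs with one padded position on each side.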
-
Expression
dynet::expr::
conv2d
(const Expression &x, const Expression &f, const Expression &b, const std::vector<unsigned> &stride, bool is_valid = true)¶ conv2d with bias
2D convolution operator with bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. When the stride is 1, the distinction is:
- SAME: the output size is the same as the input size. To do so, one needs to pad the input so the filter can sweep outside of the input maps.
- VALID: output size shrinks by filter_size - 1, and the filters always sweep at valid positions inside the input maps. No padding needed.
In detail, assume:
- Input feature maps: XH x XW x XC x N
- Filters: FH x FW x XC x FC
- Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.
For the SAME convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH) / float(strides[0]))
- YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as:
- pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
- pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
- pad_top = pad_along_height / 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width / 2
- pad_right = pad_along_width - pad_left
For the VALID convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH - FH + 1) / float(strides[0]))
- YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zeros.
- Return
- The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
- Parameters
x
: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
f
: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
b
: The bias, one value per output channel (1D: Co)
stride
: The row and column strides
is_valid
: ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’)
Normalization Operations¶
This includes batch normalization and the like.
-
Expression
dynet::expr::
layer_norm
(const Expression &x, const Expression &g, const Expression &b)¶ Layer normalization.
Performs layer normalization:
\( \begin{split} \mu &= \frac 1 n \sum_{i=1}^n x_i\\ \sigma &= \sqrt{\frac 1 n \sum_{i=1}^n (x_i-\mu)^2}\\ y&=\frac {\boldsymbol{g}} \sigma \circ (\boldsymbol{x}-\mu) + \boldsymbol{b}\\ \end{split} \)
Reference : Ba et al., 2016
- Return
- An expression of the same dimension as x
- Parameters
x
: Input expression (possibly batched)
g
: Gain (same dimension as x, no batch dimension)
b
: Bias (same dimension as x, no batch dimension)
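A standalone sketch of the layer normalization formula for a single unbatched vector. Like the formula, it uses the biased (divide by \(n\)) standard deviation and adds no epsilon, so it asserts \(\sigma > 0\) instead; the function name is illustrative and this is not DyNet's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// y = g / sigma (elementwise) * (x - mu) + b, with mu and sigma the mean and
// biased standard deviation of x.
std::vector<double> layer_norm_sketch(const std::vector<double>& x, const std::vector<double>& g,
                                      const std::vector<double>& b) {
  assert(!x.empty() && g.size() == x.size() && b.size() == x.size());
  const double n = static_cast<double>(x.size());
  double mu = 0.0;
  for (double v : x) mu += v;
  mu /= n;
  double var = 0.0;
  for (double v : x) var += (v - mu) * (v - mu);
  const double sigma = std::sqrt(var / n);
  assert(sigma > 0.0);  // a constant input leaves the formula undefined
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = g[i] / sigma * (x[i] - mu) + b[i];
  return y;
}
```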
Builders¶
Builders combine various operations to implement more complicated things such as recurrent and LSTM networks.
-
struct
dynet::
LSTMBuilder
¶ - #include <lstm.h>
LSTMBuilder creates an LSTM unit with coupled input and forget gates as well as peephole connections.
More specifically, here are the equations for the dynamics of this cell:
\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+W_{ic}c_{t-1}+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+W_{oc}c_{t}+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
Inherits from dynet::RNNBuilder
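To make the coupled-gate equations above concrete, here is a hedged, dependency-free Python sketch of one time step of a single layer. The parameter names in the dictionary p (W_ix, W_ih, ...) are hypothetical and simply mirror the symbols in the equations; they are not DyNet's internal names. The peephole terms use full matrices W_ic and W_oc, exactly as written in the equations.

```python
import math


def matvec(W, v):
    """Dense matrix-vector product; W is a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def vadd(*vectors):
    """Element-wise sum of several equally sized vectors."""
    return [sum(parts) for parts in zip(*vectors)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def coupled_lstm_step(p, x, h_prev, c_prev):
    """One time step of the coupled-gate LSTM with peepholes (equations above)."""
    i = [sigmoid(z) for z in vadd(matvec(p["W_ix"], x), matvec(p["W_ih"], h_prev),
                                  matvec(p["W_ic"], c_prev), p["b_i"])]
    c_tilde = [math.tanh(z) for z in vadd(matvec(p["W_cx"], x), matvec(p["W_ch"], h_prev), p["b_c"])]
    # Coupled gates: the forget gate is (1 - i_t), so c_t = c_{t-1} + (c~_t - c_{t-1}) * i_t
    c = [cp + (ct - cp) * it for cp, ct, it in zip(c_prev, c_tilde, i)]
    # The output-gate peephole looks at the *new* cell state c_t
    o = [sigmoid(z) for z in vadd(matvec(p["W_ox"], x), matvec(p["W_oh"], h_prev),
                                  matvec(p["W_oc"], c), p["b_o"])]
    h = [math.tanh(ct) * ot for ct, ot in zip(c, o)]
    return h, c


def zero_params(input_dim, hidden_dim):
    """All-zero parameters, handy for checking the equations by hand."""
    def zeros(rows, cols):
        return [[0.0] * cols for _ in range(rows)]
    p = {}
    for gate in ("i", "c", "o"):
        p["W_%sx" % gate] = zeros(hidden_dim, input_dim)
        p["W_%sh" % gate] = zeros(hidden_dim, hidden_dim)
        p["b_" + gate] = [0.0] * hidden_dim
    p["W_ic"] = zeros(hidden_dim, hidden_dim)
    p["W_oc"] = zeros(hidden_dim, hidden_dim)
    return p


# With all-zero weights, i_t = o_t = 0.5 and c~_t = 0, so c_t = 0.5 * c_{t-1}.
h, c = coupled_lstm_step(zero_params(3, 2), x=[0.3, -1.0, 2.0], h_prev=[0.0, 0.0], c_prev=[1.0, -2.0])
print(c)  # [0.5, -1.0]
```

Coupling the gates saves one set of gate parameters compared with a separate forget gate, at the cost of forcing "forget" and "write" to trade off against each other.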
Public Functions
LSTMBuilder()¶ Default constructor.
LSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, Model &model)¶ Constructor for the LSTMBuilder.
- Parameters
layers
: Number of layers
input_dim
: Dimension of the input \(x_t\)
hidden_dim
: Dimension of the hidden states \(h_t\) and \(c_t\)
model
: Model holding the parameters
unsigned num_h0_components() const¶ Number of components in h_0.
For LSTMBuilder, this corresponds to 2 * layers because it includes the initial cell state \(c_0\).
- Return
- 2 * layers
std::vector<Expression> get_s(RNNPointer i) const¶ Get the state of the hidden layers at the step pointed to by i.
For LSTMBuilder, this consists of a vector of the memory cell values for each layer (l1, l2, l3, ...), followed by the hidden state values.
- Return
- {c_{l1}, c_{l2}, ..., h_{l1}, h_{l2}, ...}
void set_dropout(float d)¶ Set the dropout rates to a single value.
This has the same effect as set_dropout(d,d_h,d_c) except that all the dropout rates are set to the same value.
- Parameters
d
: Dropout rate to be applied on all of \(x,h,c\)
void set_dropout(float d, float d_h, float d_c)¶ Set the dropout rates.
The dropout implemented here is an adaptation of the variational dropout with tied weights introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\), \(\mathbf{z_c}\sim \mathrm{Bernoulli}(1-d_c)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:
\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ih}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{ic}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t-1})+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ch}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{oh}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{oc}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t})+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation.
- Parameters
d
: Dropout rate \(d_x\) for the input \(x_t\)
d_h
: Dropout rate \(d_h\) for the output \(h_t\)
d_c
: Dropout rate \(d_c\) for the cell \(c_t\)
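The key point of the variational dropout described above is that each mask is drawn once per sequence and then reused (tied) at every time step, with kept units rescaled by \(\frac 1 {1-d}\). The following sketch illustrates that idea in plain Python; sample_dropout_mask and apply_mask are hypothetical helpers, not DyNet functions, and per-batch-element masks (see set_dropout_masks) are left out.

```python
import random


def sample_dropout_mask(dim, d, rng):
    """Sample z ~ Bernoulli(1 - d) per unit and pre-scale it by 1 / (1 - d).

    Assumes 0 <= d < 1. Kept units are multiplied by 1/(1-d) so that the
    expected value of each masked unit is unchanged.
    """
    keep = 1.0 - d
    return [1.0 / keep if rng.random() < keep else 0.0 for _ in range(dim)]


def apply_mask(mask, v):
    return [m * x for m, x in zip(mask, v)]


rng = random.Random(1234)
# The mask is sampled once, at the start of the sequence ...
z_x = sample_dropout_mask(6, d=0.5, rng=rng)
sequence = [[1.0] * 6, [2.0] * 6, [3.0] * 6]
# ... and the same (tied) mask is reused at every time step.
masked = [apply_mask(z_x, x_t) for x_t in sequence]
print(z_x)
```

Sampling a fresh mask at every step would instead give standard (untied) dropout; tying the mask is what makes the scheme "variational" in the sense of Gal, 2016.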
void disable_dropout()¶ Set all dropout rates to 0.
This is equivalent to set_dropout(0) or set_dropout(0,0,0).
void set_dropout_masks(unsigned batch_size = 1)¶ Set dropout masks at the beginning of a sequence for a specific batch size.
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element.
- Parameters
batch_size
: Batch size
struct dynet::VanillaLSTMBuilder¶
#include <lstm.h>
VanillaLSTMBuilder creates a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.
This cell runs according to the following dynamics:
\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
Inherits from dynet::RNNBuilder
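For comparison with the coupled variant, one step of the “standard” LSTM equations above can be sketched as follows. This is a plain-Python illustration, not DyNet code; the keys of the dictionary p (W_ix, W_fh, b_o, ...) are hypothetical names mirroring the equations, and the +1 in the forget gate is taken directly from them.

```python
import math


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def vadd(*vectors):
    return [sum(parts) for parts in zip(*vectors)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def vanilla_lstm_step(p, x, h_prev, c_prev):
    """One step of the 'standard' LSTM above (no peepholes, separate forget gate).

    p holds W_{g}x, W_{g}h and b_{g} for g in i, f, o, c (illustrative names).
    """
    def pre(g):
        return vadd(matvec(p["W_%sx" % g], x), matvec(p["W_%sh" % g], h_prev), p["b_" + g])

    i = [sigmoid(z) for z in pre("i")]
    f = [sigmoid(z + 1.0) for z in pre("f")]  # the "+1" forget-gate bias from the equations
    o = [sigmoid(z) for z in pre("o")]
    c_tilde = [math.tanh(z) for z in pre("c")]
    c = [cp * ft + ct * it for cp, ft, ct, it in zip(c_prev, f, c_tilde, i)]
    h = [math.tanh(ct) * ot for ct, ot in zip(c, o)]
    return h, c


def zero_params(input_dim, hidden_dim):
    p = {}
    for g in ("i", "f", "o", "c"):
        p["W_%sx" % g] = [[0.0] * input_dim for _ in range(hidden_dim)]
        p["W_%sh" % g] = [[0.0] * hidden_dim for _ in range(hidden_dim)]
        p["b_" + g] = [0.0] * hidden_dim
    return p


# With zero weights, f_t = sigmoid(1) ~= 0.731, so most of c_{t-1} is kept.
h, c = vanilla_lstm_step(zero_params(2, 1), x=[1.0, 1.0], h_prev=[0.0], c_prev=[1.0])
print(round(c[0], 4))  # 0.7311
```

The +1 shifts the forget gate towards “remember” at initialization, which is visible in the zero-weight example: the cell keeps about 73% of its previous value instead of 50%.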
Public Functions
VanillaLSTMBuilder()¶ Default constructor.
VanillaLSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, Model &model, bool ln_lstm = false)¶ Constructor for the VanillaLSTMBuilder.
- Parameters
layers
: Number of layers
input_dim
: Dimension of the input \(x_t\)
hidden_dim
: Dimension of the hidden states \(h_t\) and \(c_t\)
model
: Model holding the parameters
ln_lstm
: Whether to use layer normalization
void set_dropout(float d)¶ Set the dropout rates to a single value.
This has the same effect as set_dropout(d,d_r) except that all the dropout rates are set to the same value.
- Parameters
d
: Dropout rate to be applied on all of \(x,h\)
void set_dropout(float d, float d_r)¶ Set the dropout rates.
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:
\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f+1)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation.
- Parameters
d
: Dropout rate \(d_x\) for the input \(x_t\)
d_r
: Dropout rate \(d_h\) for the output \(h_t\)
void disable_dropout()¶ Set all dropout rates to 0.
This is equivalent to set_dropout(0) or set_dropout(0,0).
void set_dropout_masks(unsigned batch_size = 1)¶ Set dropout masks at the beginning of a sequence for a specific batch size.
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element.
- Parameters
batch_size
: Batch size
struct dynet::RNNBuilder¶
#include <rnn.h>
Interface for constructing an RNN, LSTM, GRU, etc.
Subclassed by dynet::DeepLSTMBuilder, dynet::FastLSTMBuilder, dynet::GRUBuilder, dynet::LSTMBuilder, dynet::SimpleRNNBuilder, dynet::TreeLSTMBuilder, dynet::VanillaLSTMBuilder
Public Functions
RNNBuilder()¶ Default constructor.
RNNPointer state() const¶ Get pointer to the current state.
- Return
- Pointer to the current state
void new_graph(ComputationGraph &cg)¶ Initialize with a new computation graph.
Call this to reset the builder when you are working with a newly created ComputationGraph object.
- Parameters
cg
: Computation graph
void start_new_sequence(const std::vector<Expression> &h_0 = {})¶ Reset for a new sequence.
Call this after new_graph and before add_input, when starting a new sequence on the same hypergraph.
- Parameters
h_0
: h_0 is used to initialize hidden layers at timestep 0 to given values
Expression set_h(const RNNPointer &prev, const std::vector<Expression> &h_new = {})¶ Explicitly set the output state of a node.
- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous state
h_new
: The new hidden state
Expression set_s(const RNNPointer &prev, const std::vector<Expression> &s_new = {})¶ Set the internal state of a node (for LSTMs).
For RNNs without internal states (SimpleRNN, GRU...), this has the same behaviour as set_h.
- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous state
s_new
: The new state. Can be {new_c[0],...,new_c[n]} or {new_c[0],...,new_c[n], new_h[0],...,new_h[n]}
Expression add_input(const Expression &x)¶ Add another timestep by reading in the variable x.
- Return
- The hidden representation of the deepest layer
- Parameters
x
: Input variable
Expression add_input(const RNNPointer &prev, const Expression &x)¶ Add another timestep, with arbitrary recurrent connection.
This allows you to define a recurrent connection to prev rather than to head[cur]. This can be used to construct trees, implement beam search, etc.
- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous state
x
: Input variable
void rewind_one_step()¶ Rewind the last timestep.
- This DOES NOT remove the variables from the computation graph; it just means the next time step will see a different previous state. You can rewind as many times as you want.
RNNPointer get_head(const RNNPointer &p)¶ Return the RNN state that is the parent of p.
- This can be used in implementing complex structures such as trees, etc.
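The pointer-based interface (add_input with an explicit prev, rewind_one_step, get_head) effectively builds a tree of states in which each state records its parent, as the mention of head[cur] above suggests. The toy Python class below models only that bookkeeping, under the assumption that states are identified by integer indices and -1 stands for the initial state. It is a hypothetical illustration, not DyNet's implementation, and performs no RNN computation.

```python
class ToyRNNStates:
    """A hypothetical, DyNet-free model of how RNN states form a tree.

    head[k] is the parent of state k; -1 stands for the initial state.
    Only the bookkeeping is modelled: no actual RNN computation happens.
    """

    def __init__(self):
        self.head = []
        self.inputs = []
        self.cur = -1

    def add_input(self, x, prev=None):
        # Without an explicit prev, the new state continues from the current one.
        parent = self.cur if prev is None else prev
        self.head.append(parent)
        self.inputs.append(x)
        self.cur = len(self.head) - 1
        return self.cur

    def rewind_one_step(self):
        # Nothing is deleted: we only move the "current" pointer back to the parent.
        self.cur = self.head[self.cur]

    def get_head(self, p):
        return self.head[p]

    def history(self, p):
        """Inputs read on the path from the initial state to state p."""
        out = []
        while p != -1:
            out.append(self.inputs[p])
            p = self.head[p]
        return out[::-1]


rnn = ToyRNNStates()
a = rnn.add_input("the")
b = rnn.add_input("cat")
rnn.rewind_one_step()             # back to the state after "the"
c = rnn.add_input("dog")          # a sibling of b
d = rnn.add_input("sat", prev=b)  # continue an older hypothesis, as in beam search
print(rnn.history(c))  # ['the', 'dog']
print(rnn.history(d))  # ['the', 'cat', 'sat']
```

Because rewinding only moves a pointer, the state after "cat" remains reachable, which is what makes branching structures such as beams or trees cheap to express.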
void set_dropout(float d)¶ Set dropout.
- Parameters
d
: Dropout rate
void disable_dropout()¶ Disable dropout.
In general, you should disable dropout at test time.
virtual Expression back() const = 0¶ Returns node (index) of most recent output.
- Return
- Node (index) of most recent output
virtual std::vector<Expression> final_h() const = 0¶ Access the final output of each hidden layer.
- Return
- Final output of each hidden layer
virtual std::vector<Expression> get_h(RNNPointer i) const = 0¶ Access the output of any hidden layer.
- Return
- Output of each hidden layer at the given step
- Parameters
i
: Pointer to the step whose output you want to access
virtual std::vector<Expression> final_s() const = 0¶ Access the final state of each hidden layer.
This returns the state of each hidden layer, in a format that can be used in start_new_sequence (i.e. including any internal cell for LSTMs and the like).
- Return
- Vector containing, if it exists, the list of final internal states, followed by the list of final outputs for each layer
virtual std::vector<Expression> get_s(RNNPointer i) const = 0¶ Access the state of any hidden layer.
See final_s for details.
- Return
- Internal state of each hidden layer at the given step
- Parameters
i
: Pointer to the step whose state you want to access
virtual unsigned num_h0_components() const = 0¶ Number of components in h_0.
- Return
- Number of components in h_0
virtual void copy(const RNNBuilder &params) = 0¶ Copy the parameters of another builder.
- Parameters
params
: RNNBuilder you want to copy parameters from.
void save_parameters_pretraining(const std::string &fname) const¶ This function saves all the parameters associated with a particular RNNBuilder's derived class to a file.
This should not be used to serialize models; it should only be used to save parameters for pretraining. If you are interested in serializing models, use the Boost serialization API against your model class.
- Parameters
fname
: File you want to save your model to.
void load_parameters_pretraining(const std::string &fname)¶ Loads all the parameters associated with a particular RNNBuilder's derived class from a file.
This should not be used to serialize models; it should only be used to load parameters from pretraining. If you are interested in serializing models, use the Boost serialization API against your model class.
- Parameters
fname
: File you want to read your model from.
struct dynet::SimpleRNNBuilder¶
#include <rnn.h>
This provides a builder for the simplest RNN with tanh nonlinearity.
The equation for this RNN is: \(h_t=\tanh(W_x x_t + W_h h_{t-1} + b)\)
Inherits from dynet::RNNBuilder
Public Functions
SimpleRNNBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, Model &model, bool support_lags = false)¶ Builds a simple RNN.
- Parameters
layers
: Number of layers
input_dim
: Dimension of the input
hidden_dim
: Hidden layer (and output) size
model
: Model holding the parameters
support_lags
: Allow for auxiliary output?
Expression add_auxiliary_input(const Expression &x, const Expression &aux)¶ Add auxiliary output.
Returns \(h_t=\tanh(W_x x_t + W_h h_{t-1} + W_y y + b)\) where \(y\) is the auxiliary output expression aux.
- Return
- The hidden representation of the deepest layer
- Parameters
x
: Input expression
aux
: Auxiliary output expression
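The SimpleRNNBuilder recurrence \(h_t=\tanh(W_x x_t + W_h h_{t-1} + b)\) can be unrolled by hand with the following minimal single-layer sketch in plain Python (no DyNet, no auxiliary input). The function names are hypothetical; weights are lists of rows.

```python
import math


def simple_rnn_step(W_x, W_h, b, x, h_prev):
    """h_t = tanh(W_x x_t + W_h h_{t-1} + b) for a single layer."""
    return [math.tanh(sum(wx * xi for wx, xi in zip(row_x, x)) +
                      sum(wh * hi for wh, hi in zip(row_h, h_prev)) + bi)
            for row_x, row_h, bi in zip(W_x, W_h, b)]


def run_simple_rnn(W_x, W_h, b, xs, h0):
    """Apply the step to a whole sequence and return every hidden state."""
    hs, h = [], h0
    for x in xs:
        h = simple_rnn_step(W_x, W_h, b, x, h)
        hs.append(h)
    return hs


# One input unit, one hidden unit: h_1 = tanh(1.0), h_2 = tanh(0.5 * h_1)
hs = run_simple_rnn(W_x=[[1.0]], W_h=[[0.5]], b=[0.0], xs=[[1.0], [0.0]], h0=[0.0])
print([round(h[0], 4) for h in hs])  # [0.7616, 0.3634]
```

The second state is non-zero even though its input is zero, because the recurrent term carries information forward from the first step.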
Optimizers¶
The various optimizers that you can use to tune your parameters.
struct dynet::SimpleSGDTrainer¶
#include <training.h>
Stochastic gradient descent trainer.
This trainer performs stochastic gradient descent, the go-to optimization procedure for neural networks. In the standard setting, the learning rate at epoch \(t\) is \(\eta_t=\frac{\eta_0}{1+\eta_{\mathrm{decay}}t}\).
Inherits from dynet::Trainer
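The decay schedule above, together with the plain SGD step it scales, can be sketched as follows. This is an illustrative Python sketch, not DyNet's implementation; in particular, how and when DyNet advances the epoch counter \(t\) is not shown here.

```python
def sgd_learning_rate(eta0, eta_decay, t):
    """eta_t = eta_0 / (1 + eta_decay * t), with t counted in epochs."""
    return eta0 / (1.0 + eta_decay * t)


def sgd_update(params, grads, lr):
    """Plain SGD step: theta <- theta - lr * gradient."""
    return [p - lr * g for p, g in zip(params, grads)]


print([round(sgd_learning_rate(0.1, 0.5, t), 4) for t in range(4)])  # [0.1, 0.0667, 0.05, 0.04]
print(sgd_update([1.0, -2.0], [0.5, 0.5], lr=sgd_learning_rate(0.1, 0.5, 0)))
```

With \(\eta_{\mathrm{decay}} = 0\) the learning rate stays at \(\eta_0\) for every epoch.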
struct dynet::CyclicalSGDTrainer¶
#include <training.h>
Cyclical learning rate SGD.
This trainer performs stochastic gradient descent with a cyclical learning rate as proposed in Smith, 2015.
This uses a triangular function with optional exponential decay.
More specifically, at each update, the learning rate \(\eta\) is updated according to:
\( \begin{split} \text{cycle} &= \left\lfloor 1 + \frac{\texttt{it}}{2 \times\texttt{step\_size}} \right\rfloor\\ x &= \left\vert \frac{\texttt{it}}{\texttt{step\_size}} - 2 \times \text{cycle} + 1\right\vert\\ \eta &= \eta_{\text{min}} + (\eta_{\text{max}} - \eta_{\text{min}}) \times \max(0, 1 - x) \times \gamma^{\texttt{it}}\\ \end{split} \)
Reference: Cyclical Learning Rates for Training Neural Networks
Inherits from dynet::Trainer
Public Functions
CyclicalSGDTrainer(Model &m, float e0_min = 0.01, float e0_max = 0.1, float step_size = 2000, float gamma = 0.0, float edecay = 0.0)¶ Constructor.
- Parameters
m
: Model to be trained
e0_min
: Lower learning rate
e0_max
: Upper learning rate
step_size
: Half-period of the triangular function in number of iterations (not epochs), i.e. the number of iterations needed to go from the lower to the upper learning rate. According to the original paper, this should be set to around 2-8 times the number of training iterations in an epoch
gamma
: Learning rate upper bound decay parameter
edecay
: Learning rate decay parameter. Ideally you shouldn’t use this with a cyclical learning rate since decay is already handled by \(\gamma\)
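The following Python sketch evaluates the cyclical learning rate formula above at a few iterations, using the constructor's default e0_min, e0_max and step_size. It is a standalone illustration, not DyNet code. Note that the factor \(\gamma^{\texttt{it}}\) stays at 1 for every iteration only when \(\gamma = 1\); with \(\gamma = 0\) it vanishes for every it > 0 and the rate stays at \(\eta_{\text{min}}\). The sketch therefore uses gamma=1.0 (no decay) as its default.

```python
import math


def cyclical_lr(it, eta_min=0.01, eta_max=0.1, step_size=2000, gamma=1.0):
    """Triangular cyclical learning rate with exponential decay (formula above)."""
    cycle = math.floor(1 + it / (2 * step_size))
    x = abs(it / step_size - 2 * cycle + 1)
    return eta_min + (eta_max - eta_min) * max(0.0, 1 - x) * gamma ** it


for it in (0, 1000, 2000, 3000, 4000):
    print(it, round(cyclical_lr(it), 4))
# prints 0.01, 0.055, 0.1, 0.055, 0.01 for these iterations
```

The rate climbs linearly from \(\eta_{\text{min}}\) to \(\eta_{\text{max}}\) over step_size iterations and descends over the next step_size, so one full triangle spans 2 × step_size iterations.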
struct dynet::MomentumSGDTrainer¶
#include <training.h>
Stochastic gradient descent with momentum.
This is a modified version of the SGD algorithm with momentum to stabilize the gradient trajectory. The modified gradient is \(\theta_{t+1}=\mu\theta_{t}+\nabla_{t+1}\) where \(\mu\) is the momentum.
Inherits from dynet::Trainer
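Reading \(\theta\) in the formula above as the accumulated (modified) gradient, often called the velocity, one possible momentum step looks like the sketch below. It is a plain-Python illustration under that reading, with a hypothetical learning rate lr applied to the velocity; DyNet's exact update may arrange the learning rate differently.

```python
def momentum_step(params, velocity, grads, lr, mu=0.9):
    """SGD with momentum, writing the 'modified gradient' of the text as velocity.

    velocity_{t+1} = mu * velocity_t + grad_{t+1};  params -= lr * velocity_{t+1}
    """
    velocity = [mu * v + g for v, g in zip(velocity, grads)]
    params = [p - lr * v for p, v in zip(params, velocity)]
    return params, velocity


params, velocity = [0.0], [0.0]
for _ in range(3):
    params, velocity = momentum_step(params, velocity, grads=[1.0], lr=0.1)
print(velocity, params)
```

With a constant gradient of 1, the velocity grows as 1, 1.9, 2.71, ... towards \(\frac 1 {1-\mu} = 10\), which is how momentum accelerates progress along consistent directions.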
struct dynet::AdagradTrainer¶
#include <training.h>
Adagrad optimizer.
The Adagrad algorithm assigns a different learning rate to each parameter according to the following formula: \(\delta_\theta^{(t)}=-\frac{\eta_0}{\sqrt{\epsilon+\sum_{i=0}^{t}(\nabla_\theta^{(i)})^2}}\nabla_\theta^{(t)}\)
Reference: Duchi et al., 2011
Inherits from dynet::Trainer
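The per-parameter scaling of Adagrad is easiest to see numerically. The sketch below is a plain-Python illustration, not DyNet code. It adds the current squared gradient to the accumulator before computing the step, as in the formula above, so the very first step is divided by the size of the first gradient rather than by \(\sqrt{\epsilon}\) alone.

```python
import math


def adagrad_step(params, g2_sum, grads, eta0=0.1, eps=1e-8):
    """One Adagrad step. g2_sum accumulates squared gradients per parameter,
    including the current one, before the update is computed."""
    g2_sum = [s + g * g for s, g in zip(g2_sum, grads)]
    params = [p - eta0 * g / math.sqrt(eps + s) for p, g, s in zip(params, grads, g2_sum)]
    return params, g2_sum


# Gradients differing by a factor of 100 still produce (almost) equal first steps,
# because each is divided by its own accumulated magnitude.
params, g2 = adagrad_step([0.0, 0.0], [0.0, 0.0], grads=[1.0, 0.01])
print(params)  # both entries close to -0.1
```

Since the accumulator only grows, the effective learning rate of each parameter can only shrink over time; this is the behaviour that AdaDelta and RMSProp relax.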
struct dynet::AdadeltaTrainer¶
#include <training.h>
AdaDelta optimizer.
The AdaDelta optimizer is a variant of Adagrad where \(\frac{\eta_0}{\sqrt{\epsilon+\sum_{i=0}^{t}(\nabla_\theta^{(i)})^2}}\) is replaced by \(\frac{\sqrt{\epsilon+\sum_{i=0}^{t-1}\rho^{t-i-1}(1-\rho)(\delta_\theta^{(i)})^2}}{\sqrt{\epsilon+\sum_{i=0}^{t}\rho^{t-i}(1-\rho)(\nabla_\theta^{(i)})^2}}\), i.e. both the squared updates and the squared gradients are accumulated as exponentially decaying averages with parameter \(\rho\). This eliminates the need for an initial learning rate.
Reference: ADADELTA: An Adaptive Learning Rate Method
Inherits from dynet::Trainer
Public Functions
struct dynet::RMSPropTrainer¶
#include <training.h>
RMSProp optimizer.
The RMSProp optimizer is a variant of Adagrad where the squared sum of previous gradients is replaced with a moving average with parameter \(\rho\).
Inherits from dynet::Trainer
Public Functions
RMSPropTrainer(Model &m, real e0 = 0.001, real eps = 1e-08, real rho = 0.9, real edecay = 0.0)¶ Constructor.
- Parameters
m
: Model to be trained
e0
: Initial learning rate
eps
: Bias parameter \(\epsilon\) in the Adagrad formula
rho
: Update parameter for the moving average (rho = 0 is equivalent to using Adagrad)
edecay
: Learning rate decay parameter
struct dynet::AdamTrainer¶
#include <training.h>
Adam optimizer.
The Adam optimizer is similar to RMSProp but uses bias-corrected estimates of the first and second moments of the gradient.
Reference: Adam: A Method for Stochastic Optimization
Inherits from dynet::Trainer
Public Functions
AdamTrainer(Model &m, float e0 = 0.001, float beta_1 = 0.9, float beta_2 = 0.999, float eps = 1e-8, real edecay = 0.0)¶ Constructor.
- Parameters
m
: Model to be trained
e0
: Initial learning rate
beta_1
: Moving average parameter for the mean
beta_2
: Moving average parameter for the variance
eps
: Bias parameter \(\epsilon\)
edecay
: Learning rate decay parameter
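A single Adam step, including the bias correction that distinguishes it from RMSProp, can be sketched as follows. It follows the update from the Adam paper referenced above and uses the constructor's default hyperparameters. This is a standalone Python illustration rather than DyNet's implementation.

```python
import math


def adam_step(params, m, v, grads, t, e0=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8):
    """One Adam update (Kingma & Ba); t is the 1-based step count."""
    m = [beta_1 * mi + (1 - beta_1) * g for mi, g in zip(m, grads)]
    v = [beta_2 * vi + (1 - beta_2) * g * g for vi, g in zip(v, grads)]
    # Bias correction: m and v start at zero, so early estimates are scaled up.
    m_hat = [mi / (1 - beta_1 ** t) for mi in m]
    v_hat = [vi / (1 - beta_2 ** t) for vi in v]
    params = [p - e0 * mh / (math.sqrt(vh) + eps) for p, mh, vh in zip(params, m_hat, v_hat)]
    return params, m, v


# On the first step m_hat = g and v_hat = g^2, so each parameter moves by about e0.
params, m, v = adam_step([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], grads=[3.0, -0.5], t=1)
print(params)  # approximately [-0.001, 0.001]
```

Without the bias correction, the first step would be scaled by \(\frac{1-\beta_1}{\sqrt{1-\beta_2}} \approx 3.2\) relative to the corrected one, because both moment estimates start at zero.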
struct dynet::Trainer¶
#include <training.h>
General trainer struct.
Subclassed by dynet::AdadeltaTrainer, dynet::AdagradTrainer, dynet::AdamTrainer, dynet::CyclicalSGDTrainer, dynet::MomentumSGDTrainer, dynet::RMSPropTrainer, dynet::SimpleSGDTrainer
Public Functions
Trainer(Model &m, real e0, real edecay = 0.0)¶ General constructor for a Trainer.
- Parameters
m
: Model to be trained
e0
: Initial learning rate
edecay
: Learning rate decay
void update(real scale = 1.0)¶ Update parameters.
Update the parameters according to the appropriate update rule.
- Parameters
scale
: The scaling factor for the gradients
void update(const std::vector<unsigned> &updated_params, const std::vector<unsigned> &updated_lookup_params, real scale = 1.0)¶ Update subset of parameters.
Update some but not all of the parameters included in the model. This is the update_subset() function in the Python bindings. The parameters to be updated are specified by index, which can be found for Parameter and LookupParameter objects through the “index” variable (or the get_index() function in the Python bindings).
- Parameters
updated_params
: The parameter indices to be updated
updated_lookup_params
: The lookup parameter indices to be updated
scale
: The scaling factor for the gradients
float clip_gradients(real scale)¶ Clip gradient.
If clipping is enabled and the gradient is too big, return the amount to scale the gradient by (otherwise 1).
- Return
- The appropriate scaling factor
- Parameters
scale
: The clipping limit
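One common clipping scheme, global L2-norm clipping, is sketched below under the assumption that this is what “the gradient is too big” refers to: if the norm of all gradients exceeds a threshold, every gradient is scaled by threshold / norm. The function clip_scale and its threshold argument are hypothetical; DyNet's clip_gradients signature and exact semantics are as documented above.

```python
import math


def clip_scale(grads, threshold):
    """Factor to multiply all gradients by under global L2-norm clipping.

    Returns threshold / ||g|| when the norm exceeds the threshold, otherwise 1.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    return threshold / norm if norm > threshold else 1.0


grads = [3.0, 4.0]              # ||g|| = 5
print(clip_scale(grads, 1.0))   # 0.2
print(clip_scale(grads, 10.0))  # 1.0
```

Scaling all gradients by one shared factor preserves the direction of the update and only limits its length.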
Public Members
bool sparse_updates_enabled¶ Whether to perform sparse updates.
DyNet trainers support two types of updates for lookup parameters, sparse and dense. Sparse updates are the default. They have the potential to be faster, as they only touch the parameters that have non-zero gradients. However, they may not always be faster (particularly on GPU with mini-batch training), and are not precisely numerically correct for some update rules such as MomentumSGDTrainer and AdamTrainer. Thus, if you set this variable to false, the trainer will perform dense updates and be precisely correct, and may sometimes be faster.
Examples¶
Here are some simple models coded in the examples of DyNet. Feel free to use and modify them.
Feed-forward models¶
Although DyNet was primarily built for natural language processing, it is still possible to code feed-forward networks. Here are some building blocks and examples for doing so.
enum ffbuilders::Activation¶ Common activation functions used in multilayer perceptrons.
Values:
SIGMOID¶
: Sigmoid function \(x\longrightarrow \frac {1} {1+e^{-x}}\)
TANH¶
: Tanh function \(x\longrightarrow \frac {1-e^{-2x}} {1+e^{-2x}}\)
RELU¶
: Rectified linear unit \(x\longrightarrow \max(0,x)\)
LINEAR¶
: Identity function \(x\longrightarrow x\)
SOFTMAX¶
: Softmax function \(\textbf{x}=(x_i)_{i=1,\dots,n}\longrightarrow \left(\frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}\right)_{i=1,\dots,n}\)
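The activation functions listed above can be written out directly, as in this plain-Python sketch. The softmax subtracts the maximum input before exponentiating. This does not change the result, because the factor cancels in the ratio, but it avoids overflow. These helpers are illustrative only and are not part of ffbuilders.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tanh_via_exp(x):
    # Same formula as TANH above; math.tanh computes the same function more robustly.
    return (1.0 - math.exp(-2.0 * x)) / (1.0 + math.exp(-2.0 * x))


def relu(x):
    return max(0.0, x)


def linear(x):
    return x


def softmax(xs):
    # Subtracting max(xs) leaves the result unchanged (it cancels in the ratio)
    # but avoids overflow in exp for large inputs.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


print(sigmoid(0.0), relu(-2.0), linear(3.5))  # 0.5 0.0 3.5
print(softmax([1000.0, 1000.0]))              # [0.5, 0.5]
```

Unlike the other four, softmax is not applied element by element: each output depends on the whole input vector.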
struct Layer¶
#include <mlp.h>
Simple layer structure.
Contains all parameters defining a layer.
Public Functions
Layer(unsigned input_dim, unsigned output_dim, Activation activation, float dropout_rate)¶ Build a feed-forward layer.
- Parameters
input_dim
: Input dimension
output_dim
: Output dimension
activation
: Activation function
dropout_rate
: Dropout rate
struct MLP¶
#include <mlp.h>
Simple multilayer perceptron.
Public Functions
MLP(Model &model)¶ Default constructor.
Don't forget to add layers!
MLP(Model &model, vector<Layer> layers)¶ Returns a multilayer perceptron.
Creates a feed-forward multilayer perceptron based on a list of layer descriptions.
- Parameters
model
: Model to contain parameters
layers
: Layers description
void append(Model &model, Layer layer)¶ Append a layer at the end of the network.
- Parameters
model
: Model to contain the parameters of the new layer
layer
: Description of the layer to append
Expression run(Expression x, ComputationGraph &cg)¶ Run the MLP on an input vector/batch.
- Return
- Output expression of the network (vector or batch)
- Parameters
x
: Input expression (vector or batch)
cg
: Computation graph
Expression get_nll(Expression x, vector<unsigned> labels, ComputationGraph &cg)¶ Return the negative log likelihood for the (batched) pair (x,y).
For a batched input \(\{x_i\}_{i=1,\dots,N}\), \(\{y_i\}_{i=1,\dots,N}\), this computes \(-\sum_{i=1}^N \log(P(y_i\vert x_i))\) where \(P(\textbf{y}\vert x_i)\) is modelled with \(\mathrm{softmax}(\mathrm{MLP}(x_i))\).
- Return
- Expression for the negative log likelihood on the batch
- Parameters
x
: Input batch
labels
: Output labels
cg
: Computation graph
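The quantity computed by get_nll can be illustrated without DyNet: for each example, take the log-softmax of the network's scores at the gold label, then sum over the batch and negate. In the sketch below, plain lists of scores stand in for the MLP outputs; the helper names are hypothetical.

```python
import math


def log_softmax_pick(scores, label):
    """log softmax(scores)[label], computed stably via log-sum-exp."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return scores[label] - log_z


def batch_nll(batch_scores, labels):
    """-sum_i log P(y_i | x_i), with P(. | x_i) = softmax(scores_i)."""
    return -sum(log_softmax_pick(s, y) for s, y in zip(batch_scores, labels))


# Two examples with uniform scores over 2 classes: each contributes log 2.
print(batch_nll([[0.0, 0.0], [5.0, 5.0]], [0, 1]))  # 2 * log(2) ~= 1.3863
```

A confident, correct prediction contributes a loss close to zero, while a uniform prediction over K classes contributes log K.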
int predict(Expression x, ComputationGraph &cg)¶ Predict the most probable label.
Returns the argmax of the softmax of the network's output.
- Return
- Label index
- Parameters
x
: Input
cg
: Computation graph
void enable_dropout()¶ Enable dropout.
This is meant to be used during training, or during testing if you want to sample outputs using Monte Carlo dropout.
void disable_dropout()¶ Disable dropout.
Do this during testing if you want a deterministic network.
bool is_dropout_enabled()¶ Check whether dropout is enabled or not.
- Return
- Dropout state
Language models¶
Language modelling is one of the cornerstones of natural language processing. DyNet allows great flexibility in the creation of neural language models. Here are some examples.
template <class Builder>
struct RNNBatchLanguageModel¶
#include <rnnlm-batch.h>
This structure wraps any RNN to train a language model with minibatching.
Recurrent neural network based language modelling maximizes the likelihood of a sentence \(\textbf s=(w_1,\dots,w_n)\) by modelling it as:
\(L(\textbf s)=p(w_1,\dots,w_n)=\prod_{i=1}^n p(w_i\vert w_1,\dots,w_{i-1})\)
where \(p(w_i\vert w_1,\dots,w_{i-1})\) is given by the output of the RNN at step \(i\).
In the case of training with minibatching, the sentences must be of the same length in each minibatch. This requires some preprocessing (see train_rnnlm-batch.cc for example).
Reference: Mikolov et al., 2010
- Template Parameters
Builder
: This can be any RNNBuilder
Public Functions
RNNBatchLanguageModel(Model &model, unsigned LAYERS, unsigned INPUT_DIM, unsigned HIDDEN_DIM, unsigned VOCAB_SIZE)¶ Constructor for the batched RNN language model.
- Parameters
model
: Model to hold all parameters for training
LAYERS
: Number of layers of the RNN
INPUT_DIM
: Embedding dimension for the words
HIDDEN_DIM
: Dimension of the hidden states
VOCAB_SIZE
: Size of the input vocabulary
Expression getNegLogProb(const vector<vector<int>> &sents, unsigned id, unsigned bsize, unsigned &tokens, ComputationGraph &cg)¶ Computes the negative log probability on a batch.
- Return
- Expression for \(-\sum_{s\in\mathrm{batch}}\log(p(s))\)
- Parameters
sents
: Full training set
id
: Start index of the batch
bsize
: Batch size (id + bsize should not exceed the size of the dataset)
tokens
: Number of tokens processed by the model (used for loss per token computation)
cg
: Computation graph
void RandomSample(const dynet::Dict &d, int max_len = 150, float temp = 1.0)¶ Samples a string of words/characters from the model.
This can be used to debug and/or have fun. Try it on new datasets!
- Parameters
d
: Dictionary to use (should be same as the one used for training)
max_len
: Maximum number of tokens to generate
temp
: Temperature for sampling (the softmax computed is \(\frac{e^{\frac{r_t^{(i)}}{T}}}{\sum_{j=1}^{\vert V\vert}e^{\frac{r_t^{(j)}}{T}}}\)). Intuitively, a lower temperature concentrates the distribution on the most likely tokens (more “standard” samples), while a higher temperature flattens it
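The temperature-scaled softmax used for sampling can be sketched as below: scores are divided by T before the softmax, and an index is then drawn from the resulting distribution. This is an illustrative plain-Python version with hypothetical helper names, not the code of the example itself.

```python
import math
import random


def temperature_softmax(scores, temp=1.0):
    """softmax(scores / T), as in the formula for the temp parameter."""
    scaled = [s / temp for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, rng):
    """Draw an index i with probability probs[i] (inverse-CDF sampling)."""
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1  # guard against floating-point round-off


scores = [1.0, 2.0, 0.5]
for temp in (0.1, 1.0, 10.0):
    print(temp, [round(p, 3) for p in temperature_softmax(scores, temp)])
rng = random.Random(0)
print([sample_index(temperature_softmax(scores, 1.0), rng) for _ in range(10)])
```

As T approaches 0 the sampler approaches greedy (argmax) decoding, and as T grows the distribution approaches uniform.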
Sequence to sequence models¶
DyNet is well suited to the variety of sequence-to-sequence models used in modern NLP. Here are some pre-coded structs implementing the most common one.
template <class Builder>
struct EncoderDecoder¶
#include <encdec.h>
This structure is a “vanilla” encoder-decoder model.
This sequence to sequence network models the conditional probability \(p(y_1,\dots,y_m\vert x_1,\dots,x_n)=\prod_{i=1}^m p(y_i\vert \textbf{e},y_1,\dots,y_{i-1})\) where \(\textbf{e}=ENC(x_1,\dots,x_n)\) is an encoding of the input sequence produced by a recurrent neural network.
Typically \(\textbf{e}\) is the concatenated cell and output vector of a (multilayer) LSTM.
Sequence to sequence models were introduced in Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation.
Our implementation is more akin to the one from Sequence to sequence learning with neural networks.
- Template Parameters
Builder
: This can theoretically be any RNNBuilder. It has only been tested with an LSTM so far.
Public Functions
EncoderDecoder()¶ Default constructor.
EncoderDecoder(Model &model, unsigned num_layers, unsigned input_dim, unsigned hidden_dim, bool bwd = false)¶ Creates an EncoderDecoder.
- Parameters
model
: Model holding the parameters
num_layers
: Number of layers (same in the encoder and decoder)
input_dim
: Dimension of the word/char embeddings
hidden_dim
: Dimension of the hidden states
bwd
: Set to true to make the encoder bidirectional. This doubles the number of parameters in the encoder. This will also add parameters for an affine transformation from the bidirectional encodings (of size num_layers * 2 * hidden_dim) to encodings of size num_layers * hidden_dim compatible with the decoder
Expression encode(const vector<vector<int>> &isents, unsigned id, unsigned bsize, unsigned &chars, ComputationGraph &cg)¶ Batched encoding.
Encodes a batch of sentences of the same size (don’t forget to pad them).
- Return
- The expression for the (batched) encoding
- Parameters
isents
: Whole dataset
id
: Index of the start of the batch
bsize
: Batch size
chars
: Number of tokens processed (used to compute loss per character)
cg
: Computation graph
Expression encode(const vector<int> &insent, ComputationGraph &cg)¶ Single sentence version of encode.
Note: this just creates a trivial dataset and feeds it to the batched version with batch size 1. It’s not very efficient, so don’t use it for training.
- Return
- Expression of the encoding
- Parameters
insent
: Input sentence
cg
: Computation graph
Expression decode(const Expression i_nc, const vector<vector<int>> &osents, int id, int bsize, ComputationGraph &cg)¶ Batched decoding.
- Return
- Expression for the negative log likelihood
- Parameters
i_nc
: Encoding (should be batched)
osents
: Output sentences dataset
id
: Start index of the batch
bsize
: Batch size (should be consistent with the shape of i_nc)
cg
: Computation graph
Expression decode(const Expression i_nc, const vector<int> &osent, ComputationGraph &cg)¶ Single sentence version of decode.
For similar reasons as encode, this is not very efficient. Use the batched version directly for training.
- Return
- Expression for the negative log likelihood
- Parameters
i_nc
: Encoding
osent
: Output sentence
cg
: Computation graph
vector<int> generate(const vector<int> &insent, ComputationGraph &cg)¶ Generate a sentence from an input sentence.
Samples at each timestep during decoding. Possible variations are greedy decoding and beam search for better performance.
- Return
- Generated sentence (indices in the dictionary)
- Parameters
insent
: Input sentence
cg
: Computation graph
vector<int> generate(Expression i_nc, unsigned oslen, ComputationGraph &cg)¶ Generate a sentence from an encoding.
You can use this directly to generate random sentences.
- Return
- Generated sentence (indices in the dictionary)
- Parameters
i_nc
: Input encoding
oslen
: Maximum length of output
cg
: Computation graph
More advanced topics are below:
Minibatching¶
Minibatching Overview¶
Minibatching takes multiple training examples and groups them together to be processed simultaneously, often allowing for large gains in computational efficiency due to the fact that modern hardware (particularly GPUs, but also CPUs) has very efficient vector processing instructions that can be exploited with appropriately structured inputs.
As shown in the figure below, common examples of this in neural networks include grouping together matrix-vector multiplies from multiple examples into a single matrix-matrix multiply, or performing an element-wise operation (such as tanh) over multiple vectors at the same time as opposed to processing single vectors individually.

In most neural network toolkits, mini-batching is largely left to the user, with a bit of help from the toolkit. This is usually done by adding an additional dimension to the tensor that they are interested in processing, and ensuring that all operations consider this dimension when performing processing. This adds some cognitive load, as the user must keep track of this extra batch dimension in all their calculations, and also ensure that they use the correct ordering of the batch dimensions to achieve maximum computational efficiency. Users must also be careful when performing operations that combine batched and unbatched elements (such as batched hidden states of a neural network and unbatched parameter matrices or vectors), in which case they must concatenate vectors into batches, or “broadcast” the unbatched element, duplicating it along the batch dimension to ensure that there are no illegal dimension mismatches.
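To see why grouping examples helps, note that multiplying W by a matrix whose columns are the individual inputs gives exactly the same columns as multiplying W by each input separately. The plain-Python sketch below checks that equivalence; it cannot show the speed-up itself, which comes from optimized matrix-matrix routines on real hardware. The helper names are illustrative.

```python
def matvec(W, v):
    """W (list of rows) times a single column vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def matmul(A, B):
    """A (n x k) times B (k x m), both given as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def stack_columns(vectors):
    """Build the matrix whose j-th column is vectors[j] (one column per example)."""
    return [list(row) for row in zip(*vectors)]


W = [[1.0, 2.0], [3.0, 4.0]]
batch = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]  # three examples
one_by_one = [matvec(W, x) for x in batch]     # three matrix-vector products
batched = matmul(W, stack_columns(batch))      # one matrix-matrix product
columns = [list(col) for col in zip(*batched)]
print(one_by_one == columns)  # True
```

Here the stacked columns play the role of the extra batch dimension described above: the user still thinks in terms of individual examples, while the computation is done once for all of them.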
DyNet hides much of this complexity from the user through the use of specially designed batching operations which treat the number of mini-batch elements not as another standard dimension, but as a special dimension with particular semantics. Broadcasting is done behind the scenes by each operation implemented in DyNet, and thus the user must only think about inputting multiple pieces of data for each batch, and calculating losses using multiple labels.
First, let’s take a look at a non-minibatched example using the Python API.
In this example, we look up word embeddings word_1 and word_2 using lookup parameters E.
We then perform an affine transform using weights W and bias b, and perform a softmax.
Finally, we calculate the loss given the true label out_label.
# in_words is a tuple (word_1, word_2)
# out_label is an output label
word_1 = E[in_words[0]]
word_2 = E[in_words[1]]
scores_sym = W*dy.concatenate([word_1, word_2])+b
loss_sym = dy.pickneglogsoftmax(scores_sym, out_label)
Next, let’s take a look at the mini-batched version:
# in_words is a list [(word_{1,1}, word_{1,2}), (word_{2,1}, word_{2,2}), ...]
# out_labels is a list of output labels [label_1, label_2, ...]
word_1_batch = dy.lookup_batch(E, [x[0] for x in in_words])
word_2_batch = dy.lookup_batch(E, [x[1] for x in in_words])
scores_sym = W*dy.concatenate([word_1_batch, word_2_batch])+b
loss_sym = dy.sum_batches( dy.pickneglogsoftmax_batch(scores_sym, out_labels) )
We can see there are only 4 major changes: the word IDs need to be transformed into lists of IDs instead of a single ID, we need to call lookup_batch instead of the standard lookup, we need to call pickneglogsoftmax_batch instead of the unbatched version, and we need to call sum_batches at the end to sum the loss from all the batch elements.
A full example of mini-batching in action for a recurrent neural language model can be found here for C++ and Python.
The Mini-batch Dimension¶
The way DyNet handles this is by using a special privileged “mini-batch element” dimension, which indicates the number of training examples in the mini-batch. To give an example, we can declare a Dim object in C++
Dim d({2,4,8}, 16)
or Python
d = Dim([2,4,8], 16)
Here, 2,4,8 are the dimensions of the data in the tensor for each example, while 16 is the number of examples in the mini-batch. When we print out the dimensions (for example when calling the print_graphviz() functionality for debugging), this will be printed as {2,4,8x16}.
Mini-batched Functions¶
For the great majority of standard operations, things should work seamlessly for minibatched elements. The one condition is that all inputs must have either one mini-batch element only, or the same number of mini-batch elements. So a binary function f(x,y) could take inputs where the numbers of minibatch elements in x/y are 1/1, 4/1, 1/4, or 4/4. However, it is not possible to have different non-one numbers of minibatch elements, such as x/y having minibatch sizes of 2/4.
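The compatibility rule above can be written down as a short check. This is a hypothetical helper, not part of DyNet, and it assumes (as the 4/1 and 1/4 cases suggest) that an input with a single mini-batch element is broadcast across the other input's batch.

```python
def combined_batch_size(*batch_sizes: int) -> int:
    """Return the mini-batch size of an operation's output, or raise if the inputs clash.

    Inputs with a single mini-batch element are treated as broadcast across the others;
    every input with more than one element must have the same count.
    """
    non_one = {b for b in batch_sizes if b != 1}
    if len(non_one) > 1:
        raise ValueError(f"incompatible mini-batch sizes: {batch_sizes}")
    return non_one.pop() if non_one else 1

for x, y in [(1, 1), (4, 1), (1, 4), (4, 4)]:
    print(f"{x}/{y} -> {combined_batch_size(x, y)}")
# combined_batch_size(2, 4) raises ValueError
```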
There are some operations where we need to explicitly think about batching, mostly on the input and output sides of the graph. These include input operations:
- lookup() (C++) and lookup_batch() (Python): Performs lookup over a vector of input IDs, where each input ID is an element of the mini-batch.
- input(): C++ input can specify a Dim object that is mini-batched. In Python, directly adding batched input is not supported yet, but there is a workaround <https://github.com/clab/dynet/issues/175> using reshape().
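The reshape() workaround relies on the batch being the outermost dimension, so that a flat buffer of batch_size * n values can be viewed as batch_size consecutive chunks of n values. The pure-Python sketch below, built around a hypothetical split_into_batch helper, only illustrates that layout under this assumption; it is not DyNet's API, and the exact DyNet calls are described in the linked issue.

```python
from typing import List, Sequence

def split_into_batch(flat: Sequence[float], batch_size: int) -> List[List[float]]:
    """View a flat buffer as `batch_size` contiguous per-example chunks."""
    if batch_size <= 0 or len(flat) % batch_size != 0:
        raise ValueError("buffer length must be a positive multiple of batch_size")
    n = len(flat) // batch_size  # values per mini-batch element
    return [list(flat[b * n:(b + 1) * n]) for b in range(batch_size)]

# Three examples, each a 2-dimensional input vector, concatenated end to end.
examples = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
flat = [v for ex in examples for v in ex]
assert split_into_batch(flat, 3) == examples
```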
Loss calculation operations:
- pickneglogsoftmax() (C++) and pickneglogsoftmax_batch() (Python): Calculates the negative log softmax loss over multiple batch elements.
- hinge() (C++): Similarly, calculates hinge loss over multiple elements.
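For intuition about the batched hinge loss, here is a pure-Python sketch of the multi-class hinge formulation, sum over j != index of max(0, margin - s[index] + s[j]), applied independently to each mini-batch element. We believe this matches what DyNet's hinge() computes, but treat that as an assumption; hinge and hinge_batch here are hypothetical stand-ins.

```python
from typing import List, Sequence

def hinge(scores: Sequence[float], index: int, margin: float = 1.0) -> float:
    """Multi-class hinge loss: sum over j != index of max(0, margin - s[index] + s[j])."""
    return sum(
        max(0.0, margin - scores[index] + s)
        for j, s in enumerate(scores)
        if j != index
    )

def hinge_batch(score_batch: List[Sequence[float]], indices: List[int],
                margin: float = 1.0) -> List[float]:
    """One hinge loss per mini-batch element, each with its own correct index."""
    return [hinge(s, i, margin) for s, i in zip(score_batch, indices)]

losses = hinge_batch([[3.0, 1.0, 2.5], [0.0, 1.0, 0.0]], [0, 1])
print(losses)  # [0.5, 0.0]
```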
Manipulation operations:
- reshape(): Can be used to reshape into tensors with a mini-batch size of more than one.
- pick() (C++) and pick_batch() (Python): Picks an element for each of the mini-batch elements.
- sum_batches(): Sums together all of the values in the batch. This is often used to sum the loss over the mini-batch before performing the backward step.
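The pick-then-sum pattern from the list above can be sketched in plain Python. The functions below are hypothetical stand-ins that operate on lists, not the DyNet expressions of the same names: pick_batch takes one index per mini-batch element, and sum_batches collapses the batch into a single value, as you would do with a loss before the backward step.

```python
from typing import List, Sequence

def pick_batch(batch: List[Sequence[float]], indices: List[int]) -> List[float]:
    """Pick element indices[b] from mini-batch element b."""
    if len(batch) != len(indices):
        raise ValueError("need one index per mini-batch element")
    return [vec[i] for vec, i in zip(batch, indices)]

def sum_batches(values: List[float]) -> float:
    """Collapse the mini-batch dimension by summing, e.g. to get a single scalar loss."""
    return sum(values)

batch = [[0.2, 1.5], [0.7, 0.1], [2.0, 0.3]]  # three mini-batch elements
picked = pick_batch(batch, [1, 0, 1])          # [1.5, 0.7, 0.3]
total = sum_batches(picked)                    # approximately 2.5
```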
Multi-processing¶
In addition to minibatch support, the DyNet C++ API also supports training models using many CPU cores (Python support is pending). This is particularly useful when training networks that are not conducive to simple mini-batching, such as tree-structured networks.
DyNet abstracts most of the behind-the-scenes grit from the user.
The user defines a function to be called for each datum in the training data set, and passes this function, along with an array of data, to DyNet.
Internally, DyNet launches a pool of training processes and automatically handles passing data examples to each worker.
Each worker process individually processes a datum, computing the results of the forward and backward passes and the gradients with respect to each parameter, and passes these results back to the parent process via a shared memory variable.
Whenever the parent process, which is also processing data, completes a gradient computation, it averages all of the gradients currently in the shared memory gradient storage and updates all parameters with respect to that average gradient.
In this way, running training on n cores is similar to training with a stochastic minibatch size with an expected value of approximately n.
This method is quite efficient, achieving nearly linear speedups with increasing numbers of cores, due to its lockless nature.
Examples of how to use the multi-processing API can be found in the xor-mp and rnnlm-mp sections of the examples/cpp directory.
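The parent-side averaging step described above can be simulated sequentially in plain Python. This sketch does not use multiple processes or any DyNet API; the helper names are hypothetical, the shared memory store is stood in for by a list, and plain SGD is assumed as the update rule for illustration.

```python
from typing import Dict, List

Params = Dict[str, float]

def average_gradients(worker_grads: List[Params]) -> Params:
    """Average the gradients currently in the (simulated) shared gradient storage."""
    if not worker_grads:
        raise ValueError("no gradients to average")
    keys = worker_grads[0].keys()
    return {k: sum(g[k] for g in worker_grads) / len(worker_grads) for k in keys}

def apply_update(params: Params, grad: Params, lr: float) -> None:
    """Plain SGD step against the averaged gradient, updating params in place."""
    for k, g in grad.items():
        params[k] -= lr * g

params = {"w": 1.0, "b": 0.0}
# Gradients from three datums processed in parallel (the parent plus two workers).
shared = [{"w": 0.3, "b": 0.1}, {"w": 0.6, "b": -0.1}, {"w": 0.0, "b": 0.3}]
apply_update(params, average_gradients(shared), lr=0.5)
```

Averaging over whatever gradients happen to be present is what makes the effective minibatch size stochastic, with an expected value close to the number of cores.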
Unorthodox Design¶
There are a couple of design decisions in DyNet that differ from the way things are implemented in other libraries, or from the way you might expect things to be implemented. The items below are a list of these unorthodox design decisions, which you should read to avoid being surprised. We also try to give some justification for these decisions (although we realize that this is not the only right way to do things).
Sparse Updates¶
By default, DyNet parameter optimizers perform sparse updates over LookupParameters. This means that if you have a LookupParameters object, use a certain subset of indices, then perform a parameter update, the optimizer will loop over the used subset, and not perform any updates over the unused values. This can improve efficiency in some cases: e.g. if you have embeddings for a vocabulary of 100,000 words and you only use 5 of them in a particular update, this will avoid doing updates over all 100,000. However, there are two things to be careful of. First, this means that some update rules, such as those using momentum (e.g. MomentumSGDTrainer and AdamTrainer), are not strictly correct (these could be made correct with some effort, but this would complicate the programming interface, which we have opted against). Second, on GPUs, because large operations are relatively cheap, it can sometimes be faster to just perform a single operation over all of the parameters, as opposed to multiple small operations. In this case, you can set the sparse_updates_enabled variable of your Trainer to false, and DyNet will perform a standard dense update, which is guaranteed to be exactly correct, and potentially faster on GPU.
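To see why momentum-based rules are only approximately correct under sparse updates, consider a toy lookup table with one scalar per row and textbook momentum SGD (v = mu * v - lr * g; w = w + v). This is a hypothetical sketch, not DyNet's trainer code. A row that was used earlier still carries velocity, so a dense update keeps moving it even when its current gradient is zero, while a sparse update skips it. With mu = 0 (plain SGD) the two agree, because a zero gradient produces no change.

```python
from typing import Iterable, List, Set

def momentum_step(weights: List[float], velocity: List[float], grads: List[float],
                  used_rows: Set[int], lr: float = 0.1, mu: float = 0.9,
                  sparse: bool = True) -> None:
    """One momentum-SGD step on a toy lookup table with one scalar per row (in place).

    sparse=True touches only the rows used in this step; sparse=False touches every row.
    """
    rows: Iterable[int] = used_rows if sparse else range(len(weights))
    for r in rows:
        velocity[r] = mu * velocity[r] - lr * grads[r]
        weights[r] += velocity[r]

def run(sparse: bool, mu: float = 0.9) -> List[float]:
    weights, velocity = [0.0, 0.0], [0.0, 0.0]
    # Step 1: both rows are used, each with gradient 1.0.
    momentum_step(weights, velocity, [1.0, 1.0], {0, 1}, mu=mu, sparse=sparse)
    # Step 2: only row 0 is used, so row 1 has a zero gradient.
    momentum_step(weights, velocity, [1.0, 0.0], {0}, mu=mu, sparse=sparse)
    return weights

print(run(sparse=True))   # row 1 stays where step 1 left it
print(run(sparse=False))  # row 1 keeps moving because of its leftover velocity
```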
Weight Decay¶
As described in the Command Line Options, weight decay is implemented through the option --dynet-weight-decay. If this value is set to wd, each parameter in the model is multiplied by (1-wd) after every parameter update. This weight decay is similar to L2 regularization, and is equivalent to it when using simple SGD (SimpleSGDTrainer), but it is not the same when using other optimizers such as AdagradTrainer or AdamTrainer.
You can still try to use weight decay with these optimizers, and it might work, but if you really want to correctly apply L2 regularization with these optimizers, you will have to directly calculate the squared L2 norm of each of the parameters and add it to the objective function before performing your update.
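A scalar sketch of the difference, assuming textbook SGD and Adagrad update rules (these hypothetical functions are not DyNet's trainers). With an L2 penalty of (lam/2) * w**2 and lam = wd / lr, SGD with the penalty matches shrinking by (1 - wd) and then taking a plain SGD step. Here the decay is applied before the gradient step so the match is exact; applying it after, as described above, only adds a (1 - wd) factor on the gradient term. Under Adagrad the penalty's gradient is rescaled by the per-parameter history, so the two no longer agree.

```python
import math
from typing import Tuple

def sgd_l2_step(w: float, g: float, lr: float, lam: float) -> float:
    """SGD on loss + (lam / 2) * w**2; the penalty contributes lam * w to the gradient."""
    return w - lr * (g + lam * w)

def sgd_decay_step(w: float, g: float, lr: float, wd: float) -> float:
    """Shrink the weight by (1 - wd), then take a plain SGD step on the unregularized gradient."""
    return (1.0 - wd) * w - lr * g

def adagrad_step(w: float, g: float, lr: float, hist: float,
                 eps: float = 1e-8) -> Tuple[float, float]:
    """One Adagrad step; returns (new weight, updated sum of squared gradients)."""
    hist += g * g
    return w - lr * g / (math.sqrt(hist) + eps), hist

w, g, lr, wd = 2.0, 0.5, 0.1, 0.01
lam = wd / lr  # this choice of L2 strength makes the two SGD variants coincide
print(math.isclose(sgd_l2_step(w, g, lr, lam), sgd_decay_step(w, g, lr, wd)))  # True

# Under Adagrad the L2 term is rescaled by the per-parameter history, so they diverge.
w_l2, _ = adagrad_step(w, g + lam * w, lr, 0.0)
w_decay, _ = adagrad_step((1.0 - wd) * w, g, lr, 0.0)
print(w_l2, w_decay)  # roughly 1.9 and 1.88
```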
Minibatching Implementation¶
Minibatching in DyNet is different from how it is implemented in other libraries. In other libraries, you can create minibatches by explicitly adding another dimension to each of the variables that you want to process, and managing them yourself. Instead, DyNet provides special operations that allow you to perform input, lookup, or loss calculation over mini-batched input, and DyNet handles the rest. The programming paradigm is a bit different from other toolkits and may take some getting used to, but it is often more convenient once you’re used to it.
LSTM Implementation¶
The implementation of LSTMs in LSTMBuilder is not the canonical implementation, but an implementation using coupled input and forget gates, as described in “LSTM: A Search Space Odyssey” (https://arxiv.org/abs/1503.04069). In other words, if the value of the input gate is i, the forget gate is 1-i. This reduces the number of parameters in the model and speeds up training slightly, and in many cases the accuracy is the same or better. If you want to try the standard version of the LSTM, use the VanillaLSTMBuilder class.
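The gate coupling can be shown on the memory-cell update alone. This scalar sketch uses hypothetical function names and omits everything in a real LSTM except the cell update c_t = f * c_{t-1} + i * g; in the coupled variant f is fixed to 1 - i, which is why no separate forget-gate weights are needed.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def standard_cell(c_prev: float, i: float, f: float, g: float) -> float:
    """Canonical LSTM cell update: c_t = f * c_{t-1} + i * g, with independent gates."""
    return f * c_prev + i * g

def coupled_cell(c_prev: float, i: float, g: float) -> float:
    """Coupled variant: the forget gate is fixed to 1 - i, so it needs no weights of its own."""
    return standard_cell(c_prev, i, 1.0 - i, g)

i = sigmoid(0.3)     # input gate from some pre-activation
g = math.tanh(-0.8)  # candidate cell value
c = coupled_cell(2.0, i, g)
```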
Dropout Scaling¶
When using dropout to help prevent overfitting, dropout is generally applied at training time, then at test time all the nodes in the neural net are used to make the final decision, increasing robustness. However, because there is a disconnect between the number of nodes being used in each situation, it is important to scale the values of the output to ensure that they match in both situations. There are two ways to do this:
- Vanilla Dropout: At training time, drop each node with probability p (keeping it with probability 1-p). At test time, scale the outputs of each node by 1-p.
- Inverted Dropout: At training time, drop each node with probability p, and scale the outputs of the remaining nodes by 1/(1-p). At test time, use the outputs as-is.
The first is perhaps more common, but the second is more convenient, because we only need to think about dropout at training time, and thus DyNet opts to use the latter. See here for more details on these two methods.
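A minimal pure-Python sketch of inverted dropout, where p denotes the probability of dropping a value. The inverted_dropout function is a hypothetical illustration, not DyNet's dropout expression. Surviving values are scaled by 1/(1-p) so the expected value at training time matches the unscaled test-time output, which is why nothing needs to change at test time.

```python
import random
from typing import List, Optional

def inverted_dropout(xs: List[float], p: float, train: bool,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Drop each value with probability p at training time and rescale survivors by 1/(1-p).

    At test time the input is returned unchanged.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not train:
        return list(xs)
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    # rng.random() is uniform on [0, 1), so a value is kept with probability 1 - p.
    return [x * scale if rng.random() >= p else 0.0 for x in xs]

rng = random.Random(0)
out = inverted_dropout([1.0] * 100_000, p=0.5, train=True, rng=rng)
print(sum(out) / len(out))  # close to 1.0: the expected value is preserved
```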
And we welcome your contributions!
Contributing to DyNet¶
You are very welcome to contribute to DyNet, be it to correct a bug or add a feature. Here are some guidelines to guarantee consistency.
Coding Tips and Style¶
Coding Tips¶
Adding New Operations: One of the most common things that one will want to do to modify DyNet is to add a new operation to calculate a new function. You can find more information on how to do so at the end of the tutorial slides here.
Coding Practices¶
Testing:
Before committing any code, tests should be run to make sure that the new code didn’t break anything.
This can be done by using the make test command.
It is also highly recommended that you add unit tests for any new functionality.
Unit tests are implemented in the tests directory.
When making a bug fix, you can add a test that broke before the fix but passes afterwards.
That being said, tests are not an absolute requirement, so if you have a contribution but aren’t sure how to do tests, please don’t let this stop you from contributing.
Coding Style Conventions¶
DyNet (the main version in C++) has certain coding style standards:
Overall Philosophy: DyNet is designed to minimize the computational overhead when creating networks. Try to avoid doing slow things like creating objects or copying memory in places that will be called frequently during computation graph construction.
Function Names: Function names are written in “snake_case”.
const: Always use const if the input to a function is constant.
Pointer vs. Reference: When writing functions, use the following guidelines (quoted from here):
- Only pass a value by pointer if the value 0/NULL is a valid input in the current context.
- If a function argument is an out-value, then pass it by reference.
- Choose “pass by value” over “pass by const reference” only if the value is a POD (Plain Old Datastructure) or small enough (memory-wise) or in other ways cheap enough (time-wise) to copy.
Error handling: The C++ core of DyNet provides a mechanism for error handling that should be used in all code. It consists of 3 macros as follows (included in globals.h):
- DYNET_INVALID_ARG(msg): This is used to throw an error that is triggered when a user passes an invalid argument to one of the functions.
- DYNET_RUNTIME_ERR(msg): This is used to throw an error that could be triggered by a user, but is not the result of an invalid argument. For example, it could be used when something is not implemented yet, or when the program dies due to lack of memory, etc.
- DYNET_ASSERT(expr,msg): This is to be used to check things that should only happen due to a programming error within DyNet itself, and should never be triggered by a user. expr is a condition, and msg is a message explaining the exception, with ostream-style formatting.
Documentation¶
DyNet uses Doxygen for commenting the code and Sphinx for the general documentation.
If you’re only documenting features, you don’t need to concern yourself with Sphinx; your Doxygen comments will be integrated into the documentation automatically.
Doxygen guidelines¶
Please document any publicly accessible function you write using the doxygen syntax.
You can see examples in the training file. The most important thing is to use /* style comments and \command style commands.
For ease of access, the documentation is divided into groups. For now the groups are optimizers and operations. If you implement a function that falls into one of these groups, add \ingroup [group name] at the beginning of your comment block.
If you want to create a group, use \defgroup [group-name] at the beginning of your file. Then create a file for this group in Sphinx (see next section).
Important: You can use LaTeX in Doxygen comments with the syntax \f$ \f$. For some reason, since ReadTheDocs updated their version of Sphinx, \f[ \f] doesn’t work anymore, so don’t use it; it breaks the build.
Sphinx guidelines¶
The Sphinx source files are located in doc/source. They describe the documentation’s organization using the reStructuredText markup language.
Although reStructuredText is more powerful than Markdown, it might feel less intuitive, especially when writing long documents. If need be, you can write your doc in Markdown and convert it using Pandoc.
For a tutorial on Sphinx see their tutorial.
Doxygen-generated XML is integrated into the Sphinx files using the Breathe module. The only Breathe command used now is doxygengroup. You shouldn’t use commands for individual classes/functions/structs without a good reason. Most information should be put in the Doxygen comments.
Building the docs¶
The documentation is automatically rebuilt by ReadTheDocs each time you push to GitHub.
If you want to build the documentation locally, you’ll need to install Doxygen, Sphinx and Breathe and then run build_doc.sh from the doc folder.