LSTM, GRU and RNN implementation for ChainerX #7764

Merged Aug 19, 2019 (67 commits).

Commits
c1ed7a5
implementing simple lstm function
dido1998 May 16, 2019
f88460e
implementation of lstm activation function
dido1998 May 18, 2019
f6954d2
implementation of n_step_lstm
dido1998 May 25, 2019
2f142de
fix formatting
dido1998 May 25, 2019
0db2e42
fix formatting
dido1998 May 25, 2019
3124124
merge with master
dido1998 May 25, 2019
a11cb63
fix compile errors
dido1998 May 25, 2019
7adc07e
fix python formatting
dido1998 May 25, 2019
3830e68
fix errors in test_connection.py
dido1998 May 25, 2019
c449375
fix formatting
dido1998 May 25, 2019
e6755ca
fix syntax errors
dido1998 May 25, 2019
0a77595
fix return error
dido1998 May 25, 2019
b66996b
fix c formatting
dido1998 May 25, 2019
5692af0
implementation of n_step_bi_lstm
dido1998 May 26, 2019
d01d703
fix c formatting
dido1998 May 27, 2019
b0ee988
added cudnn rnn forward implementation
dido1998 Jun 6, 2019
14b2d0b
cudnn rnn implementation
dido1998 Jun 13, 2019
405bc36
remove unnecessary debug messages
dido1998 Jun 13, 2019
6736a02
adding remaining files
dido1998 Jun 14, 2019
3a0ec64
Merge branch 'LSTM' of https://github.com/dido1998/chainer into LSTM
dido1998 Jun 14, 2019
e9cdba8
cudnn code fixed
dido1998 Jun 19, 2019
e5ab97b
adding incomplete test code for cuda LSTM
dido1998 Jun 20, 2019
74085fa
completed cudnn error-free code
dido1998 Jun 23, 2019
6d48e97
changing usage of AsContiguous
dido1998 Jun 24, 2019
da42d96
adding deleted file
dido1998 Jun 27, 2019
87c7678
using the same descriptor for forward and backward pass
dido1998 Jun 27, 2019
fdcaa1a
merge with chainerx master
dido1998 Jun 28, 2019
9914ee8
fixing compile errors
dido1998 Jun 28, 2019
f38147a
working backward pass for cudnn n_step_lstm
dido1998 Jun 29, 2019
ca6a79d
fixing cpplint errors
dido1998 Jun 29, 2019
c6b726b
fixing clang errors
dido1998 Jun 29, 2019
a03b9a1
fix clang errors
dido1998 Jun 29, 2019
84ed04d
adding GRU base function
dido1998 Jun 30, 2019
3cf04fc
native implementation of n_step_gru and n_step_bigru
dido1998 Jul 2, 2019
91a9b8c
cudnn implementation of n_step_gru and n_step_bigru
dido1998 Jul 3, 2019
1a363d2
fix clang errors
dido1998 Jul 3, 2019
9cb248a
Merge pull request #18 from chainer/master
dido1998 Jul 4, 2019
e8d9135
fix incorrect include
dido1998 Jul 4, 2019
8e1092a
fix clang errors
dido1998 Jul 4, 2019
698ee3c
Merge branch 'GRU' of https://github.com/dido1998/chainer into GRU
dido1998 Jul 4, 2019
06fcd9c
fix incorrect include
dido1998 Jul 4, 2019
0993771
minor fixes
dido1998 Jul 4, 2019
5c2a912
fix minor clang error
dido1998 Jul 4, 2019
5653d91
fix minor clang errors
dido1998 Jul 4, 2019
0243f76
vanilla rnn implementation
dido1998 Jul 14, 2019
5fbf512
minor changes
dido1998 Jul 15, 2019
3022131
fix clang errors
dido1998 Jul 15, 2019
d129f81
changing tolerance for float16 and handling non-differentiable point …
dido1998 Jul 18, 2019
5623a03
adding documentation and making minor changes
dido1998 Aug 1, 2019
5b2ea53
fix clang errors
dido1998 Aug 5, 2019
233026b
Merge pull request #30 from chainer/master
dido1998 Aug 7, 2019
1cec24f
integrating chainer with chainerx
dido1998 Aug 7, 2019
2fc9110
Merge pull request #33 from chainer/master
dido1998 Aug 8, 2019
8063ba4
integrating chainerx into chainer for GRU
dido1998 Aug 8, 2019
5a95348
integrating chainerx into chainer for lstm
dido1998 Aug 8, 2019
fd6347c
fix clang errors
dido1998 Aug 13, 2019
3dbce9d
Merge pull request #34 from chainer/master
dido1998 Aug 13, 2019
4b7295b
integrating chainer with chainerx for n_step_lstm
dido1998 Aug 13, 2019
3c5b822
converting names to camelcase and minor changes
dido1998 Aug 15, 2019
4911c33
making cosmetic changes
dido1998 Aug 15, 2019
74a7346
fix clang errors
dido1998 Aug 15, 2019
83a1c6c
fix clang errors
dido1998 Aug 16, 2019
becf72c
Merge pull request #36 from chainer/master
dido1998 Aug 16, 2019
9d292f7
cosmetic fixes
dido1998 Aug 17, 2019
07fb9ff
Merge branch 'VanillaRNN' of https://github.com/dido1998/chainer into…
dido1998 Aug 17, 2019
2767db3
minor changes
dido1998 Aug 19, 2019
a8d9806
fix clang-tidy errors
dido1998 Aug 19, 2019
integrating chainer with chainerx for n_step_lstm
dido1998 committed Aug 13, 2019
commit 4b7295b2c54233e6a97fb90bcf72882023b1e149
95 changes: 93 additions & 2 deletions chainer/functions/rnn/n_step_lstm.py
@@ -2,6 +2,7 @@

import chainer
from chainer import backend
from chainer import variable
from chainer.backends import cuda
from chainer.functions.array import reshape
from chainer.functions.array import stack
@@ -16,6 +17,70 @@
cudnn = cuda.cudnn


def _extract_apply_in_data(inputs):
if not inputs:
return False, ()

if chainerx.is_available():
has_chainerx_array = False

# Unwrap arrays
arrays = []
for x in inputs:
if isinstance(x, variable.Variable):
if x._has_chainerx_array:
arrays.append(x._data[0])
has_chainerx_array = True
else:
arrays.append(x.array)
else: # x is ndarray
arrays.append(x)
if not has_chainerx_array:
if isinstance(x, chainerx.ndarray):
has_chainerx_array = True
return has_chainerx_array, tuple(arrays)
else:
return False, tuple([
x.array if isinstance(x, variable.Variable) else x
for x in inputs])

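# Usage sketch, assuming `v` is a Variable backed by a chainerx.ndarray and
# `a` is a plain NumPy array:
#
#     has_chx, arrays = _extract_apply_in_data([v, a])
#     # has_chx -> True, since at least one input is backed by ChainerX;
#     # arrays  -> (v's raw ndarray, a), i.e. the unwrapped arrays in input order.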

def _combine_inputs(hx, cx, ws, bs, xs, num_layers, directions):
combined = []
combined.append(hx)
combined.append(cx)
for x in xs:
combined.append(x)

for n in range(num_layers):
for direction in range(directions):
            idx = directions * n + direction

for i in range(8):
combined.append(ws[idx][i])
for i in range(8):
combined.append(bs[idx][i])
return combined


def _seperate_inputs(combined, num_layers, seq_length, directions):
    hx = combined[0]
    cx = combined[1]
    xs = combined[2: 2 + seq_length]
    ws = []
    bs = []
    index = 2 + seq_length
    for n in range(num_layers):
        ws.append(combined[index: index + 8])
        bs.append(combined[index + 8: index + 16])
        index += 16
        if directions == 2:
            ws.append(combined[index: index + 8])
            bs.append(combined[index + 8: index + 16])
            index += 16
    return hx, cx, ws, bs, xs

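# Layout sketch for the flat `combined` list produced by _combine_inputs and
# consumed by _seperate_inputs, assuming n_layers=1, directions=1 and a
# sequence of length 3:
#
#     combined = [hx, cx,            # hidden and cell states
#                 x0, x1, x2,        # the input sequence
#                 w0, ..., w7,       # 8 weight arrays for layer 0
#                 b0, ..., b7]       # 8 bias arrays for layer 0
#
# For each additional layer (and, when directions == 2, for each direction)
# another block of 8 weights followed by 8 biases is appended.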

def _stack_weight(ws):
# TODO(unno): Input of the current LSTM implementation is shuffled
w = stack.stack(ws, axis=1)
@@ -416,11 +481,37 @@ def n_step_lstm_base(

xp = backend.get_array_module(hx, hx.data)

# TODO(imanishi): Support ChainerX n_step_rnn
use_cuda = xp is cuda.cupy or (
xp is chainerx and hx.device.device.backend.name == 'cuda')

if use_cuda and chainer.should_use_cudnn('>=auto', 5000):
directions = 1
if use_bi_direction:
directions = 2

combined = _combine_inputs(hx, cx, ws, bs, xs, n_layers, directions)
has_chainerx_array, combined = _extract_apply_in_data(combined)
hx_chx, cx_chx, ws_chx, bs_chx, xs_chx = _seperate_inputs(
combined, n_layers, len(xs), directions)

if has_chainerx_array and xp is chainerx and dropout_ratio == 0:
if use_bi_direction:
hy, cy, ys = chainerx.n_step_bilstm(
n_layers, hx_chx, cx_chx, ws_chx, bs_chx, xs_chx)
else:
hy, cy, ys = chainerx.n_step_lstm(
n_layers, hx_chx, cx_chx, ws_chx, bs_chx, xs_chx)
hy = variable.Variable._init_unchecked(
hy, requires_grad=hy.is_backprop_required(),
is_chainerx_array=True)
cy = variable.Variable._init_unchecked(
cy, requires_grad=cy.is_backprop_required(),
is_chainerx_array=True)
ys = [variable.Variable._init_unchecked(
y, requires_grad=y.is_backprop_required(),
is_chainerx_array=True)
for y in ys]
return hy, cy, ys
elif use_cuda and chainer.should_use_cudnn('>=auto', 5000):
lengths = [len(x) for x in xs]
xs = chainer.functions.concat(xs, axis=0)
with chainer.using_device(xs.device):