Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dido1998 committed Aug 19, 2019
1 parent 07fb9ff commit 2767db3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
3 changes: 1 addition & 2 deletions chainerx_cc/chainerx/routines/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 66,8 @@ std::vector<Array> ExtractGates(const Array& x) {
Shape shape{shape_vec};
Array x_r = Reshape(x, shape);
std::vector<Array> gates = Split(x_r, 4, 2);
int index = 0;
for (auto& gate : gates) {
gates[index ] = Squeeze(gate);
gate = Squeeze(gate);
}
return gates;
}
Expand Down
10 changes: 4 additions & 6 deletions chainerx_cc/chainerx/routines/n_step_rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 154,18 @@ std::vector<std::vector<Array>> OneDirectionalLoop(
}
std::vector<Array> h_list;
for (auto& x : xs) {
Array x_t = x;

if (x_t.shape()[0] > h.shape()[0]) {
throw DimensionError{"The batch size of x must be equal to or less than the size of state", x_t.shape(), ' ', h.shape()};
if (x.shape()[0] > h.shape()[0]) {
throw DimensionError{"The batch size of x must be equal to or less than the size of state", x.shape(), ' ', h.shape()};
}
std::vector<int64_t> indices_h;
indices_h.emplace_back(x_t.shape()[0]);
indices_h.emplace_back(x.shape()[0]);
indices_h.emplace_back(h.shape()[0]);
std::vector<Array> h_split = Split(h, indices_h, 0);
std::vector<Array> c_split;
std::vector<Array> h_c;
if (c.has_value()) {
std::vector<int64_t> indices_c;
indices_c.emplace_back(x_t.shape()[0]);
indices_c.emplace_back(x.shape()[0]);
indices_c.emplace_back(c->shape()[0]);
c_split = Split(*c, indices_c, 0);
h_c = impl(x, h_split[0], c_split[0], ws, b, activation);
Expand Down

0 comments on commit 2767db3

Please sign in to comment.