TreeLSTM unit as an activation function.

This function implements TreeLSTM units for both N-ary TreeLSTM and Child-Sum TreeLSTM. Let \(c_{\text{1}}, c_{\text{2}}, \dots, c_{\text{N}}\) be the children's cell states and \(x\) be the incoming signal.

First, the incoming signal \(x\) is split along the second axis into (3 + N) arrays \(a, i, o, f_{\text{1}}, f_{\text{2}}, \dots, f_{\text{N}}\) of the same shape. This means that the second axis of \(x\) must be (3 + N) times as long as that of each \(c_{n}\).

The split input signals correspond to:

  • \(a\) : sources of the cell input
  • \(i\) : sources of the input gate
  • \(o\) : sources of the output gate
  • \(f_{n}\) : sources of the forget gate for the n-th child
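The split can be sketched in NumPy as follows. This is a minimal illustration and not Chainer's code. `split_gates` is a hypothetical helper, and it assumes the (3 + N) arrays are contiguous, equally sized blocks taken along the second axis in the order listed above.

```python
import numpy as np


def split_gates(x, n_children):
    """Hypothetical helper: split x into a, i, o and [f_1, ..., f_N].

    Assumes contiguous blocks along axis 1 in the order a, i, o, f_1..f_N.
    """
    n_split = 3 + n_children
    if x.shape[1] % n_split != 0:
        raise ValueError("x.shape[1] must be a multiple of 3 + N")
    blocks = np.split(x, n_split, axis=1)  # (3 + N) arrays of equal shape
    a, i, o = blocks[:3]
    fs = blocks[3:]  # one forget-gate source per child
    return a, i, o, fs


# Binary tree (N = 2), batch of 4, 10 units per cell -> x has 5 * 10 columns.
x = np.arange(4 * 5 * 10, dtype=np.float32).reshape(4, 50)
a, i, o, fs = split_gates(x, n_children=2)
print(a.shape, len(fs), fs[1].shape)  # (4, 10) 2 (4, 10)
```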

Second, it computes the outputs as:

\[\begin{split}c &= \tanh(a) \text{sigmoid}(i) \\ & + c_{\text{1}} \text{sigmoid}(f_{\text{1}}) \\ & + c_{\text{2}} \text{sigmoid}(f_{\text{2}}) \\ & + \dots \\ & + c_{\text{N}} \text{sigmoid}(f_{\text{N}}), \\ h &= \tanh(c) \text{sigmoid}(o).\end{split}\]
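A NumPy sketch of the forward computation can serve as a rough cross-check of the equations. `tree_lstm_reference` and `sigmoid` are hypothetical names. The sketch assumes the same contiguous block layout for \(x\), treats every product as element-wise, and ignores gradients entirely.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def tree_lstm_reference(*inputs):
    """Hypothetical NumPy sketch: inputs = (c_1, ..., c_N, x) -> (c, h)."""
    cs, x = inputs[:-1], inputs[-1]
    gates = np.split(x, 3 + len(cs), axis=1)
    a, i, o = gates[:3]
    fs = gates[3:]
    # Cell input term: tanh(a) * sigmoid(i)
    c = np.tanh(a) * sigmoid(i)
    # One forget-gated term per child: c_n * sigmoid(f_n)
    for c_n, f_n in zip(cs, fs):
        c = c + c_n * sigmoid(f_n)
    # Output: tanh(c) * sigmoid(o)
    h = np.tanh(c) * sigmoid(o)
    return c, h


rng = np.random.default_rng(0)
c1 = rng.uniform(-1, 1, (4, 10))
c2 = rng.uniform(-1, 1, (4, 10))
x = rng.uniform(-1, 1, (4, 5 * 10))
c, h = tree_lstm_reference(c1, c2, x)
print(c.shape, h.shape)  # (4, 10) (4, 10)
```

With \(x = 0\), every sigmoid evaluates to 0.5 and \(\tanh(a) = 0\). The sketch therefore reduces to \(c = 0.5 (c_{\text{1}} + c_{\text{2}})\), which is an easy sanity check.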

All products above are element-wise. The results are returned as a tuple of two variables, \(c\) and \(h\).

Parameters: inputs (list of Variable) – Variable arguments consisting of the cell vectors of all child nodes followed by the input vector. Each cell vector and the input vector is a Variable. The second dimension of the input vector must be (N + 3) times the size of that of each cell, where N denotes the total number of cells.
Returns: Two Variable objects, c and h. c is the updated cell state and h is the outgoing signal.
Return type: tuple
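A small helper can check the argument convention (cell states first, input vector last) and the (N + 3) size constraint. `check_tree_lstm_inputs` is hypothetical and not part of Chainer. It also assumes that at least one child cell is given.

```python
import numpy as np


def check_tree_lstm_inputs(*inputs):
    """Hypothetical helper: validate (c_1, ..., c_N, x) and return N."""
    if len(inputs) < 2:  # assumption: at least one child cell
        raise ValueError("expected at least one cell and the input vector")
    cs, x = inputs[:-1], inputs[-1]
    n = len(cs)
    units = cs[0].shape[1]
    for c in cs:
        if c.shape != cs[0].shape:
            raise ValueError("all cell states must have the same shape")
    if x.shape[0] != cs[0].shape[0]:
        raise ValueError("batch sizes of cells and input differ")
    if x.shape[1] != (n + 3) * units:
        raise ValueError("x.shape[1] is %d, expected (N + 3) * %d = %d"
                         % (x.shape[1], units, (n + 3) * units))
    return n


c1 = np.zeros((4, 10))
c2 = np.zeros((4, 10))
x = np.zeros((4, 50))
print(check_tree_lstm_inputs(c1, c2, x))  # 2
```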

See the papers for details: Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks and A Fast Unified Model for Parsing and Sentence Understanding.

Tai et al.'s N-ary TreeLSTM is slightly extended in Bowman et al., and this function is based on Bowman et al.'s variant. Specifically, eq. 10 in Tai et al. has only one \(W\) matrix applied to \(x\), shared by all children. Bowman et al.'s model, on the other hand, has multiple matrices, each of which affects the forget gate of one child's cell individually.
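The difference can be made concrete with a rough NumPy illustration. It assumes that a single matrix maps the node's input y to all (3 + N) blocks of the incoming signal, with bias terms omitted. Under Tai et al.'s eq. 10, the forget-gate blocks of that matrix would be tied to one shared matrix. Under Bowman et al.'s variant, they are independent. All names and sizes here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
in_size, units, n_children = 6, 3, 2
n_split = 3 + n_children

# Bowman et al.-style: every block of W, including each child's
# forget-gate block, is an independent matrix.
w_bowman = rng.standard_normal((n_split * units, in_size))

# Tai et al.-style (eq. 10): one shared matrix maps y to the forget gates,
# so the forget-gate blocks of W are tied copies of it.
w_tai = w_bowman.copy()
w_f_shared = rng.standard_normal((units, in_size))
for k in range(n_children):
    start = (3 + k) * units
    w_tai[start:start + units] = w_f_shared

y = rng.standard_normal((1, in_size))


def forget_contributions(w, y, n_split):
    """Return the y-dependent part of each forget gate, [f_1, ..., f_N]."""
    return np.split(y @ w.T, n_split, axis=1)[3:]


f_tai = forget_contributions(w_tai, y, n_split)
f_bowman = forget_contributions(w_bowman, y, n_split)
print(np.allclose(f_tai[0], f_tai[1]))        # True
print(np.allclose(f_bowman[0], f_bowman[1]))  # False
```

In both cases the children's outputs can still influence each forget gate differently. Only the contribution from y is shared in Tai et al.'s formulation.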


Assume that y is the current input signal, c1 and c2 are the children's cell states, and h1 and h2 are the children's output signals from previous tree_lstm() calls. Each of these has n_units channels (10 in the following example). For a 2-ary (binary) TreeLSTM, the most typical preparation of x is:

>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> # Binary TreeLSTM (N = 2): each Linear maps 10 units to (3 + 2) * 10.
>>> model = chainer.Chain()
>>> with model.init_scope():
...     model.w = L.Linear(10, 5 * 10)
...     model.v1 = L.Linear(10, 5 * 10)
...     model.v2 = L.Linear(10, 5 * 10)
>>> y = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> h1 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> h2 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> c1 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> c2 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> x = model.w(y) + model.v1(h1) + model.v2(h2)
>>> c, h = F.tree_lstm(c1, c2, x)  # children's cells first, input last

This computes the input sources \(a, i, o, f_{\text{1}}, f_{\text{2}}\) from the current input y and the children's outputs h1 and h2. A separate set of parameters is used for each source: model.w for y, model.v1 for h1 and model.v2 for h2.
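The same preparation can be mirrored in plain NumPy to make the shapes explicit. The random matrices w, v1 and v2 are hypothetical stand-ins for the weights of model.w, model.v1 and model.v2. Bias terms, which L.Linear adds by default, are omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_units, n_children = 4, 10, 2
n_split = 3 + n_children

# Hypothetical weights standing in for model.w, model.v1, model.v2;
# each maps n_units -> n_split * n_units, like L.Linear(10, 5 * 10).
w, v1, v2 = (rng.standard_normal((n_split * n_units, n_units)) * 0.1
             for _ in range(3))

y = rng.uniform(-1, 1, (batch, n_units))
h1 = rng.uniform(-1, 1, (batch, n_units))
h2 = rng.uniform(-1, 1, (batch, n_units))

# Each source has its own matrix; the projections are summed into x.
x = y @ w.T + h1 @ v1.T + h2 @ v2.T

# Splitting x along the second axis yields a, i, o, f_1, f_2.
a, i, o, f1, f2 = np.split(x, n_split, axis=1)
print(x.shape, a.shape, f2.shape)  # (4, 50) (4, 10) (4, 10)
```

Because the projections are summed, every gate source (for example \(f_{\text{1}}\)) mixes contributions from y, h1 and h2 through its own slice of each matrix.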