chainer.functions.gumbel_softmax

chainer.functions.gumbel_softmax(log_pi, tau=0.1, axis=1)

Gumbel-Softmax sampling function.

This function draws samples \(y_i\) from the Gumbel-Softmax distribution,

\[y_i = {\exp((g_i + \log\pi_i)/\tau) \over \sum_{j}\exp((g_j + \log\pi_j)/\tau)},\]

where \(\tau\) is a temperature parameter and the \(g_i\) are samples drawn from the Gumbel distribution \(\mathrm{Gumbel}(0, 1)\).

See Categorical Reparameterization with Gumbel-Softmax.
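
The formula above can be illustrated with a minimal standard-library sketch. This is not Chainer's implementation: the helper names `sample_gumbel` and `gumbel_softmax_sketch` are hypothetical, the input is assumed to be a plain 1-D list of log-probabilities (so there is no `axis` handling), and the Gumbel noise is assumed to be generated by inverse-transform sampling, \(g = -\log(-\log u)\) with \(u\) uniform on \((0, 1)\).

```python
import math
import random


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample via inverse transform: g = -log(-log(u))."""
    u = rng.random()
    # Guard against u == 0.0, for which log(u) is undefined.
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax_sketch(log_pi, tau=0.1, rng=None):
    """Return one relaxed sample y for a 1-D list of log-probabilities.

    Hypothetical helper for illustration only; not part of Chainer.
    """
    rng = rng or random.Random()
    # Perturb each log-probability with Gumbel noise, then divide by tau.
    z = [(sample_gumbel(rng) + lp) / tau for lp in log_pi]
    # Numerically stable softmax: subtract the maximum before exponentiating.
    z_max = max(z)
    exp_z = [math.exp(v - z_max) for v in z]
    total = sum(exp_z)
    return [e / total for e in exp_z]


log_pi = [math.log(p) for p in (0.2, 0.3, 0.5)]
y = gumbel_softmax_sketch(log_pi, tau=0.1, rng=random.Random(0))
print(y)  # entries are non-negative and sum to 1
```

With a small temperature such as the default `tau=0.1`, each sample tends to lie close to a one-hot vector; larger values of `tau` give smoother samples.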

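Parameters

log_pi – Input representing the log-probabilities \(\log\pi\).
tau – Temperature \(\tau\). Defaults to 0.1.
axis – Axis over which the sum in the denominator is taken. Defaults to 1.
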
Returns

Output variable.

Return type

Variable