DEV Community

Super Kai (Kazuya Ito)


GELU and Mish in PyTorch



GELU() returns a 0D or more D tensor of zero or more values computed elementwise by the GELU function from a 0D or more D tensor of zero or more elements (the output has the same shape as the input), as shown below:

*Memos:

  • The 1st argument for initialization is approximate(Optional-Default:'none'-Type:str): *Memos:
    • 'none' or 'tanh' can be selected.
    • 'none' computes the exact GELU and 'tanh' computes a tanh-based approximation of it. Their results are almost the same.
  • The 1st argument is input(Required-Type:tensor of float).
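  • 'none': GELU(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / √2)), where Φ(x) is the cumulative distribution function of the standard normal distribution.
  • 'tanh': GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))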
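The exact GELU is x * Φ(x) = 0.5 * x * (1 + erf(x / √2)), and approximate='tanh' replaces the error-function term with 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))). As a minimal sketch, both formulas can be written out by hand with torch.erf and torch.tanh and compared against nn.GELU(); the variable names below are only for illustration.

```python
import math

import torch
from torch import nn

x = torch.tensor([8., -3., 0., 1., 5., -2., -1., 4.])

# Exact GELU: x * Phi(x), with the standard normal CDF written via erf.
gelu_exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation of GELU.
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# Compare with the built-in module, tolerating tiny float32 rounding differences.
print(torch.allclose(gelu_exact, nn.GELU()(x), atol=1e-5))                     # True
print(torch.allclose(gelu_tanh, nn.GELU(approximate='tanh')(x), atol=1e-5))   # True
```

The tolerance in torch.allclose() is there because the built-in kernels may evaluate erf and tanh slightly differently from these Python-level expressions.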
import torch
from torch import nn

my_tensor = torch.tensor([8., -3., 0., 1., 5., -2., -1., 4.])

gelu = nn.GELU()
gelu(input=my_tensor)
# tensor([8.0000e+00, -4.0499e-03, 0.0000e+00, 8.4134e-01,
#         5.0000e+00, -4.5500e-02, -1.5866e-01, 3.9999e+00])

gelu
# GELU(approximate='none')

gelu.approximate
# 'none'

gelu = nn.GELU(approximate='tanh')
gelu(input=my_tensor)
# tensor([8.0000e+00, -3.6374e-03, 0.0000e+00, 8.4119e-01,
#         5.0000e+00, -4.5402e-02, -1.5881e-01, 3.9999e+00])

my_tensor = torch.tensor([[8., -3., 0., 1.],
                          [5., -2., -1., 4.]])
gelu = nn.GELU()
gelu(input=my_tensor)
# tensor([[8.0000e+00, -4.0499e-03, 0.0000e+00, 8.4134e-01],
#         [5.0000e+00, -4.5500e-02, -1.5866e-01, 3.9999e+00]])

my_tensor = torch.tensor([[[8., -3.], [0., 1.]],
                          [[5., -2.], [-1., 4.]]])
gelu = nn.GELU()
gelu(input=my_tensor)
# tensor([[[8.0000e+00, -4.0499e-03], [0.0000e+00, 8.4134e-01]],
#         [[5.0000e+00, -4.5500e-02], [-1.5866e-01, 3.9999e+00]]])
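nn.GELU() forwards to torch.nn.functional.gelu(), which takes the same approximate argument. The sketch below, assuming a PyTorch version that supports approximate, shows the functional form, a 0D (scalar) input, and how far apart 'none' and 'tanh' actually are on an evenly spaced grid of inputs (the grid range is an arbitrary choice).

```python
import torch
import torch.nn.functional as F

my_tensor = torch.tensor([8., -3., 0., 1., 5., -2., -1., 4.])

# Functional form; gives the same values as nn.GELU(approximate='tanh')(my_tensor).
print(F.gelu(my_tensor, approximate='tanh'))

# A 0D (scalar) tensor works too, and the output is also 0D.
scalar_out = F.gelu(torch.tensor(1.))
print(scalar_out)        # tensor(0.8413)
print(scalar_out.dim())  # 0

# Largest absolute gap between exact GELU and the tanh approximation on [-6, 6].
x = torch.linspace(-6., 6., steps=1001)
max_diff = (F.gelu(x) - F.gelu(x, approximate='tanh')).abs().max()
print(max_diff)  # below 1e-3, which is why the two results are "almost the same"
```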

Mish() returns a 0D or more D tensor of zero or more values computed elementwise by the Mish function from a 0D or more D tensor of zero or more elements (the output has the same shape as the input), as shown below:

*Memos:

  • The 1st argument for initialization is inplace(Optional-Default:False-Type:bool): *Memos:
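    • If True, the operation is done in-place, i.e. the input tensor is overwritten with the result.
    • Keep it False: with True, the input tensor itself is modified, and in-place operations can raise errors with autograd (e.g. on a leaf tensor that requires gradients).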
  • The 1st argument is input(Required-Type:tensor of float).

Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
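Mish is defined as x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x). A minimal sketch that writes this formula out by hand with torch.nn.functional.softplus and torch.tanh and checks it against nn.Mish(); the variable names are only for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

x = torch.tensor([8., -3., 0., 1., 5., -2., -1., 4.])

# Mish(x) = x * tanh(softplus(x)), with softplus(x) = ln(1 + exp(x)).
# F.softplus is used instead of torch.log(1 + torch.exp(x)) because exp(x)
# overflows float32 for large positive x, while F.softplus returns x there.
mish_manual = x * torch.tanh(F.softplus(x))

print(torch.allclose(mish_manual, nn.Mish()(x), atol=1e-5))  # True

# The module also accepts a 0D (scalar) tensor.
print(nn.Mish()(torch.tensor(1.)))  # tensor(0.8651)
```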

import torch
from torch import nn

my_tensor = torch.tensor([8., -3., 0., 1., 5., -2., -1., 4.])

mish = nn.Mish()
mish(input=my_tensor)
# tensor([8.0000, -0.1456, 0.0000, 0.8651, 4.9996, -0.2525, -0.3034, 3.9974])

mish
# Mish()

mish.inplace
# False

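mish = nn.Mish(inplace=True)
mish(input=my_tensor) # my_tensor itself is overwritten with the result
# tensor([8.0000, -0.1456, 0.0000, 0.8651, 4.9996, -0.2525, -0.3034, 3.9974])

my_tensor
# tensor([8.0000, -0.1456, 0.0000, 0.8651, 4.9996, -0.2525, -0.3034, 3.9974])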

my_tensor = torch.tensor([[8., -3., 0., 1.],
                          [5., -2., -1., 4.]])
mish = nn.Mish()
mish(input=my_tensor)
# tensor([[8.0000, -0.1456, 0.0000, 0.8651],
#         [4.9996, -0.2525, -0.3034, 3.9974]])

my_tensor = torch.tensor([[[8., -3.], [0., 1.]],
                          [[5., -2.], [-1., 4.]]])
mish = nn.Mish()
mish(input=my_tensor)
# tensor([[[8.0000, -0.1456], [0.0000, 0.8651]],
#         [[4.9996, -0.2525], [-0.3034, 3.9974]]])
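The memo above recommends keeping inplace=False. A minimal sketch of one concrete reason, assuming the input is a leaf tensor created with requires_grad=True (as model parameters are): autograd refuses in-place Mish on such a tensor and raises a RuntimeError. The exact error message can differ between PyTorch versions, so the sketch only catches the exception type.

```python
import torch
from torch import nn

# A leaf tensor that requires gradients, like a trainable parameter.
x = torch.tensor([8., -3., 0., 1.], requires_grad=True)

raised = False
try:
    nn.Mish(inplace=True)(x)
except RuntimeError as e:
    # Autograd rejects overwriting a leaf tensor that requires gradients.
    raised = True
    print("RuntimeError:", e)

# With the default inplace=False, x is left untouched and gradients flow.
y = nn.Mish()(x)
y.sum().backward()
print(x)       # unchanged values
print(x.grad)  # d(Mish)/dx for each element
```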
