
Lam


PyTorch Quick Ref

Imports

import torch                                        # root package
from torch.utils.data import Dataset, DataLoader    # dataset representation and loading
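As a rough illustration of how `Dataset` and `DataLoader` fit together, here is a minimal sketch. The `SquaresDataset` class and its contents are hypothetical, made up for this example. A map-style dataset only needs `__len__` and `__getitem__`, and the loader takes care of batching.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __len__(self):
        # Number of samples; DataLoader uses this to plan the batches
        return len(self.x)

    def __getitem__(self, idx):
        # One (input, target) sample
        return self.x[idx], self.y[idx]


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks samples along a new leading dimension
batch_shapes = [tuple(xb.shape) for xb, yb in loader]
first_x, first_y = next(iter(loader))
print(batch_shapes)  # [(4,), (4,), (2,)]
```

The last batch is smaller because 10 is not a multiple of 4; passing `drop_last=True` to `DataLoader` would discard it instead.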

Neural Network API

import torch.autograd as autograd         # computation graph
from torch import Tensor                  # tensor node in the computation graph
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim               # optimizers, e.g. gradient descent, Adam
from torch.jit import script, trace       # hybrid frontend decorator and tracing jit
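To show how these modules are typically combined, the following is a small sketch of a training loop. The `TinyNet` model, the random data, and the hyperparameters are all assumptions chosen for illustration, not recommendations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)  # make the toy run repeatable


class TinyNet(nn.Module):
    """Hypothetical two-layer network, for illustration only."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        # F holds the stateless ops; nn holds the layers that own parameters
        return self.fc2(F.relu(self.fc1(x)))


model = TinyNet()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Made-up regression data: the target is the sum of the input features
x = torch.randn(16, 4)
y = x.sum(dim=1, keepdim=True)

losses = []
for _ in range(50):
    optimizer.zero_grad()            # clear gradients from the previous step
    loss = F.mse_loss(model(x), y)   # forward pass builds the autograd graph
    loss.backward()                  # autograd fills .grad on every parameter
    optimizer.step()                 # optimizer updates the parameters
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

The `zero_grad`, `backward`, `step` order matters because gradients accumulate in `.grad` across calls to `backward()` unless they are cleared.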

TorchScript and JIT

torch.jit.trace()         # takes your module or function and an example
                          # data input, and traces the computational steps
                          # that the data encounters as it progresses through the model

@script                   # decorator that compiles a function to TorchScript
                          # from its source, preserving data-dependent
                          # control flow that tracing would not capture
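The difference between tracing and scripting is easiest to see on a function with a data-dependent branch. The sketch below uses a hypothetical function, `double_or_zero`, invented for this example. Tracing is expected to emit a `TracerWarning` here, since it can only record the branch that the example input happens to take.

```python
import torch


def double_or_zero(x):
    # Data-dependent branch: the path taken depends on the values in x
    if x.sum() > 0:
        return x * 2
    return x * 0


positive = torch.ones(3)
negative = -torch.ones(3)

# Scripting compiles the source, so both branches are kept
scripted = torch.jit.script(double_or_zero)

# Tracing runs the function once with `positive` and records only `x * 2`
traced = torch.jit.trace(double_or_zero, positive)

scripted_out = scripted(negative)  # takes the `x * 0` branch
traced_out = traced(negative)      # still multiplies by 2

print(scripted_out)
print(traced_out)
```

For models whose forward pass has no control flow that depends on tensor values, tracing and scripting should produce equivalent results, and tracing is often the simpler option.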

ONNX

torch.onnx.export(model, dummy_input, "xxxx.proto")    # exports an ONNX formatted
                                                       # model using a trained model, dummy
                                                       # data and the desired file name

import onnx                                            # separate `onnx` package, not part of torch
model = onnx.load("alexnet.proto")                     # load an ONNX model
onnx.checker.check_model(model)                        # check that the model
                                                       # IR is well formed

print(onnx.helper.printable_graph(model.graph))        # print a human readable
                                                       # representation of the graph
