
freiberg-roman

Speed Up Your PyTorch Development Using Types

Types in Python are optional, but for any PyTorch project exceeding 1,000 lines of code, implementing them can be highly beneficial.

You've probably encountered a scenario where, after coding a model with several networks and initiating training, you realize there's a mismatch in data dimensions or type. Types help clarify the information flow. Moreover, they enable code completion tools to suggest more appropriate options.

While types won't eliminate debugging, they do make the process more focused.

Python Types

First, check if your editor or IDE supports static type checking. If not, a quick online search will reveal popular options like Mypy or Pyright.

Python's type hints resemble a lightweight version of the static type systems in Java or C++, with one key difference: they are optional annotations checked by external tools, not enforced at runtime.

Here's how the syntax looks:

variable: int = 10

For functions, specify the types of parameters and the return type:

def function(x: float, y: float) -> float:
    return x ** y

Below is a list of the most fundamental Python types you'll encounter:

int    # Whole number integers
float  # Floating point numbers
bool   # Boolean values (True or False)
str    # Strings
list   # Ordered mutable lists
tuple  # Ordered immutable sequences
dict   # Dictionaries of key-value pairs

For the latter three, use the Python typing package for more precise type specifications:

from typing import Dict, List, Tuple

Example usages:

a_dict: Dict[str, float] = {'a': 1.0, ...}
a_list: List[int] = [1, 2, 3]
a_tuple: Tuple[int, bool, float] = (1, True, 3.141)

def forward(input: List[float]) -> Dict[str, float]:
    ...

For dealing with uncertain types, Any is a versatile type that accepts any value. Use it sparingly to maintain the effectiveness of your type checker.
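As an illustration of a legitimate use of Any, consider a function that parses arbitrary configuration values (the function and key names here are hypothetical):

```python
from typing import Any, Dict

def parse_config(raw: Dict[str, str]) -> Dict[str, Any]:
    """Convert raw string settings into typed values where possible."""
    parsed: Dict[str, Any] = {}
    for key, value in raw.items():
        if value.isdigit():
            parsed[key] = int(value)       # becomes int
        elif value in ("true", "false"):
            parsed[key] = value == "true"  # becomes bool
        else:
            parsed[key] = value            # stays str
    return parsed

config = parse_config({"epochs": "10", "shuffle": "true", "device": "cuda"})
```

The values genuinely have mixed types, so Any is honest here; using it for values whose type you simply haven't thought about yet is what weakens the checker.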

The Self type (available since Python 3.11) represents the current class:

import numpy as np
from dataclasses import dataclass
from typing import List, Self, Tuple

@dataclass
class PointCloud:
    points: List[Tuple[float, float, float]]

    @classmethod
    def from_numpy(cls, arr: np.ndarray) -> Self:
        ...

Frequently recurring patterns in your codebase may warrant the creation of custom types from basic types:

from typing import Dict, List

SomeType = Dict[str, List[int]]
var: SomeType = ...

However, using expressive classes as types is often more sensible:

pcd: PointCloud = PointCloud([(1.0, 1.0, 1.0), ...])  # example from above

Be aware that your static type checker can usually infer the correct type from an assignment. Add explicit annotations in the scenarios where inference fails.
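The classic case where inference fails is an empty container, whose element type cannot be deduced from the assignment alone. A small sketch:

```python
from typing import Dict, List

# From `[]` alone, the checker cannot tell what the list will hold,
# so we annotate explicitly.
losses: List[float] = []
metrics: Dict[str, float] = {}

losses.append(0.25)
metrics["accuracy"] = 0.9
```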

Types for PyTorch

A clean architecture combined with proper types elevates the professionalism of your PyTorch codebases significantly.

Key base classes in PyTorch you should be aware of:

torch.Tensor
torch.optim.Optimizer
torch.utils.data.Dataset
torch.nn.Module

For more details, refer to the official PyTorch documentation or explore the PyTorch codebase using your code editor's symbol search feature.
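As a minimal sketch of these base classes in annotations, here is a fully typed training step (the model, names, and hyperparameters are illustrative, not from the original post):

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """A toy network used only to demonstrate annotated signatures."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: torch.Tensor,
    target: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    loss: torch.Tensor = nn.functional.mse_loss(model(batch), target)
    loss.backward()
    optimizer.step()
    return loss.item()

model = TinyNet()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_value = train_step(model, opt, torch.randn(8, 4), torch.randn(8, 1))
```

With these annotations, a checker will flag a call that passes, say, a Dataset where a Tensor is expected, before the code ever runs.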

Dataclasses are extremely useful for defining types. They allow for quick definitions of classes intended to hold and validate data, without writing an __init__ method by hand. Validation can be implemented in the __post_init__ method.

Example:

import torch
from dataclasses import dataclass

@dataclass
class PointCloud:
    points: torch.Tensor
    colors: torch.Tensor

    def __post_init__(self):
        assert self.points.shape == torch.Size([...])
        ...

Employing these types for network inputs and outputs significantly reduced my debugging time. I hope they will prove just as beneficial for you.

Common Troubleshooting for DataLoaders

If you adopt a strict dataclass typing strategy in PyTorch, you might face a unique challenge. The PyTorch DataLoader typically expects data to consist of Tensors or iterable collections of Tensors. To resolve this, reimplement the default collate_fn and provide it to the DataLoader. Here's an example:

import torch
from dataclasses import dataclass
from typing import List

from torch.utils.data import DataLoader

@dataclass
class PointCloud:
    position: torch.Tensor
    color: torch.Tensor

def custom_collate_fn(batch: List[PointCloud]) -> PointCloud:
    # Stack the per-sample tensors into batched tensors.
    positions = torch.stack([pcd.position for pcd in batch])
    colors = torch.stack([pcd.color for pcd in batch])
    return PointCloud(positions, colors)
...
loader = DataLoader(..., collate_fn=custom_collate_fn)

For distinguishing between batched and unbatched types, I usually rely on simple shape assertions.
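For instance, a pair of small helpers (hypothetical names, assuming points with three coordinates) can assert whether a tensor carries a batch dimension:

```python
import torch

def assert_unbatched(points: torch.Tensor) -> None:
    # A single point cloud: (num_points, 3)
    assert points.dim() == 2 and points.shape[-1] == 3, points.shape

def assert_batched(points: torch.Tensor) -> None:
    # A batch of point clouds: (batch, num_points, 3)
    assert points.dim() == 3 and points.shape[-1] == 3, points.shape

single = torch.randn(1024, 3)
batch = torch.randn(16, 1024, 3)
assert_unbatched(single)
assert_batched(batch)
```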

With these minor adjustments, you'll find that your types work harmoniously with the PyTorch framework.

Let me know if you use types in your code and how they've benefited your projects.
