
[PyTorch] How to hook the '.to()' or '.cuda()' method when a module's CPU and CUDA implementations differ

by kail9974 2021. 10. 5.

When you implement a custom operation such as upfirdn2d, the default (CPU) implementation and the CUDA implementation of the operation may differ.

 

import torch
import torch.nn as nn

from torch import Tensor
from typing import Callable


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.default_operation = self._load_default_operation()
        # cuda operation may need to be compiled, so set the initial value to
        # None to eliminate unnecessary compilation time.
        self.cuda_operation = None
        self.operation = self.default_operation

    def _load_default_operation(self) -> Callable[..., str]:
        # Placeholder: a real module would return its CPU implementation here.
        def default_operation(input: Tensor) -> str:
            return 'default'
        return default_operation

    def _load_cuda_operation(self) -> Callable[..., str]:
        # Placeholder: a real module would compile/load its CUDA kernel here.
        def cuda_operation(input: Tensor) -> str:
            return 'cuda'
        return cuda_operation

    def forward(self, input: Tensor) -> str:
        return self.operation(input)

Above is the `Foo` module, a custom implementation that keeps a default operation and a lazily loaded CUDA operation.

To switch `operation` between `default_operation` and `cuda_operation`, we need to hook the `.to()`, `.cuda()`, and `.cpu()` methods.

In nn.Module, all of these methods are implemented on top of `._apply(fn)`.

So the solution is to override the `._apply(fn)` method.
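The example below checks that `.cpu()`, `.to()`, and even dtype casts such as `.half()` all go through `_apply`. A hypothetical `ApplyRecorder` module overrides `_apply`, records the variable names of every `fn` it receives, and then defers to the parent implementation. This is a minimal sketch. Forwarding `*args, **kwargs` is an assumption made to stay compatible with PyTorch releases whose `_apply` accepts extra arguments.

```python
import torch
import torch.nn as nn
from typing import Any, Callable, List, Tuple


class ApplyRecorder(nn.Module):
    """Hypothetical module that logs every call to nn.Module._apply."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))
        # co_varnames of each `fn` received by _apply, e.g. ('t',).
        self.calls: List[Tuple[str, ...]] = []

    def _apply(self, fn: Callable[..., Any], *args: Any, **kwargs: Any):
        self.calls.append(fn.__code__.co_varnames)
        # Defer to nn.Module so parameters and buffers are still converted.
        return super()._apply(fn, *args, **kwargs)


rec = ApplyRecorder()
rec.cpu()               # device move
rec.to(torch.float64)   # dtype cast through .to()
rec.half()              # dtype cast through .half()
print(len(rec.calls))   # one _apply call per method, 3 in total
```

Each recorded tuple contains `t`, which is the property the rest of this post relies on.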

 

# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

class Module:
    ...

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def compute_should_use_set_data(tensor, tensor_applied):
            ...

        for key, param in self._parameters.items():
            if param is not None:
                with torch.no_grad():
                    param_applied = fn(param)
                ...

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        return self._apply(lambda t: t.cuda(device))

    def cpu(self: T) -> T:
        return self._apply(lambda t: t.cpu())

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point:
                raise TypeError('nn.Module.to only accepts floating point '
                                'dtypes, but got desired dtype={}'.format(dtype))

        def convert(t):
            if convert_to_format is not None and t.dim() == 4:
                return t.to(device, dtype if t.is_floating_point() else None,
                            non_blocking, memory_format=convert_to_format)
            return t.to(device, dtype if t.is_floating_point() else None,
                        non_blocking)

        return self._apply(convert)

The `_apply` method first recurses into every child module, then applies `fn` to each of the module's own parameters and buffers (the Tensors).

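However, `_apply` only hands the tensor-level function `fn` down the module tree. It never tells a module which device it is being moved to, and it calls no other module-level method that could react to the move. This makes it difficult to write a method call that acts differently depending on the device.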
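The recursion described above can be observed with two hypothetical modules, `Child` and `Parent`, and a recording function passed directly to the private `_apply`. The sketch shows two things. First, `fn` only ever sees tensors. Second, the child's own `_apply` is invoked during the recursion, so an override in a nested module like `Foo` is still reached when its parent is moved.

```python
import torch
import torch.nn as nn
from typing import Any, Callable, List, Tuple


class Child(nn.Module):
    """Hypothetical leaf module with one parameter and one buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.register_buffer("running", torch.zeros(2))
        self.apply_calls = 0  # how many times this module's own _apply ran

    def _apply(self, fn: Callable[..., Any], *args: Any, **kwargs: Any):
        self.apply_calls += 1
        return super()._apply(fn, *args, **kwargs)


class Parent(nn.Module):
    """Hypothetical container holding a Child and a parameter of its own."""

    def __init__(self) -> None:
        super().__init__()
        self.child = Child()
        self.bias = nn.Parameter(torch.zeros(1))


seen_shapes: List[Tuple[int, ...]] = []


def record(t: torch.Tensor) -> torch.Tensor:
    # fn receives tensors (parameters and buffers), never the modules themselves.
    seen_shapes.append(tuple(t.shape))
    return t


parent = Parent()
parent._apply(record)
print(parent.child.apply_calls)  # 1
print(sorted(seen_shapes))       # [(1,), (2,), (2,)]
```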

 

`_apply` takes only `fn` as its argument, and the `fn` built by `.to()`, `.cuda()`, or `.cpu()` takes a single argument named `t`, the Tensor to convert.

We therefore check whether the `fn` passed to `_apply` has a variable named `t`.

This can be done with `fn.__code__.co_varnames`, the tuple of a function's argument and local variable names.

If it does, we pass an empty Tensor to `fn` and read the device of the returned Tensor to learn the target device.
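The two checks can be tried in isolation. The lambdas below are simplified copies of the ones `nn.Module.cuda()`, `.cpu()`, and `.half()` build; they are written here, not imported from PyTorch. The helper names `make_cuda_fn`, `make_cpu_fn`, `make_half_fn`, `looks_like_conversion`, and `probe_device` are hypothetical. Note that the name check cannot tell a device move from a dtype cast: the `.half()` function also uses `t`, and its probe reports `cpu`.

```python
import torch
from typing import Any, Callable, Optional

TensorFn = Callable[[torch.Tensor], torch.Tensor]


# Simplified copies of the closures nn.Module passes to _apply.
def make_cuda_fn(device: Optional[int] = None) -> TensorFn:
    return lambda t: t.cuda(device)


def make_cpu_fn() -> TensorFn:
    return lambda t: t.cpu()


def make_half_fn() -> TensorFn:
    return lambda t: t.half() if t.is_floating_point() else t


def looks_like_conversion(fn: Callable[..., Any]) -> bool:
    """Name-based check: does fn have a variable called exactly `t`?"""
    code = getattr(fn, "__code__", None)
    return code is not None and "t" in code.co_varnames


def probe_device(fn: TensorFn) -> torch.device:
    """Apply fn to an empty CPU tensor and report where the result lives."""
    return fn(torch.empty(0)).device


print(looks_like_conversion(make_cuda_fn()))        # True, without calling it
print(looks_like_conversion(lambda tensor: tensor))  # False: different name
print(probe_device(make_cpu_fn()))                  # cpu
print(probe_device(make_half_fn()))                 # cpu, wherever the module is
```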

 

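import torch
import torch.nn as nn

from typing import Any, Callable, Union


class Foo(nn.Module):
    ...  # __init__, _load_default_operation, _load_cuda_operation and forward as above

    def _apply(self, fn: Callable[..., Any], *args: Any, **kwargs: Any):
        # The functions built by `.to()`, `.cuda()` and `.cpu()` take a single
        # Tensor argument named `t`; skip anything else (e.g. non-Python callables).
        code = getattr(fn, '__code__', None)
        if code is not None and 't' in code.co_varnames:
            # Convert an empty CPU Tensor to find out where `fn` sends Tensors.
            empty = torch.empty(0)
            device = fn(empty).device  # type: ignore
            if is_cuda(device):
                # Build (e.g. compile) the CUDA operation only the first time.
                if self.cuda_operation is None:
                    self.cuda_operation = self._load_cuda_operation()
                self.operation = self.cuda_operation
            else:
                self.operation = self.default_operation
        # Forward extra arguments such as `recurse`, which newer PyTorch versions accept.
        return super()._apply(fn, *args, **kwargs)


def is_cuda(device: Union[torch.device, str, None]) -> bool:
    cuda_available = torch.cuda.is_available()
    if device is None:
        # No explicit device: assume the default CUDA device when one exists.
        device = torch.device('cuda' if cuda_available else 'cpu')
    elif isinstance(device, str):
        device = torch.device(device)
    return cuda_available and device.type == 'cuda'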

The implementation is above.

 

>>> f = Foo()
>>> t = torch.empty(0)
>>> f(t)
'default'
>>> f.cuda()(t)
'cuda'
>>> f.cpu()(t)
'default'

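This is the test result of the module: `Foo` switches to `cuda_operation` after `.cuda()` and back to `default_operation` after `.cpu()` (running it requires a CUDA device).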
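The name-based check has one limitation. `.half()`, `.float()`, and `.double()` also pass a function whose argument is named `t`. Likewise, `.to(torch.float16)` without a device passes `convert(t)`. Probing an empty CPU Tensor with any of these returns a CPU Tensor, so a module already on the GPU would fall back to `default_operation`.

A possible variant, sketched below as an alternative rather than the method used above, avoids guessing. It lets PyTorch convert a 0-element non-persistent buffer and reads the device from it after `super()._apply(fn)` returns. The class name `ProbeFoo` and the buffer name `_device_probe` are hypothetical. The sketch assumes an extra empty buffer per module is acceptable, and `persistent=False` keeps that buffer out of `state_dict()`.

```python
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Callable, Optional


class ProbeFoo(nn.Module):
    """Variant of Foo that reads its device from a tiny buffer after _apply."""

    def __init__(self) -> None:
        super().__init__()
        # Hypothetical 0-element buffer; nn.Module moves it like any other buffer.
        self.register_buffer("_device_probe", torch.empty(0), persistent=False)
        self.cuda_operation: Optional[Callable[[Tensor], str]] = None
        self.operation: Callable[[Tensor], str] = self._default_operation

    @staticmethod
    def _default_operation(input: Tensor) -> str:
        return "default"

    @staticmethod
    def _load_cuda_operation() -> Callable[[Tensor], str]:
        # Stand-in for compiling/loading a real CUDA kernel.
        def cuda_operation(input: Tensor) -> str:
            return "cuda"
        return cuda_operation

    def _apply(self, fn: Callable[..., Any], *args: Any, **kwargs: Any):
        # Let PyTorch convert parameters and buffers (including the probe) first.
        result = super()._apply(fn, *args, **kwargs)
        if self._device_probe.device.type == "cuda":
            if self.cuda_operation is None:
                self.cuda_operation = self._load_cuda_operation()
            self.operation = self.cuda_operation
        else:
            self.operation = self._default_operation
        return result

    def forward(self, input: Tensor) -> str:
        return self.operation(input)


m = ProbeFoo()
x = torch.empty(0)
print(m(x))                    # default
print(m.to(torch.float64)(x))  # default: a dtype cast keeps the device
if torch.cuda.is_available():
    print(m.cuda()(x))         # cuda
    print(m.half()(x))         # still cuda: the probe stays on the GPU
```

Because the device is read after conversion, dtype-only calls leave the current operation unchanged.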
