| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Custom replacement for `torch.nn.functional.grid_sample` that |
| supports arbitrarily high order gradients between the input and output. |
| Only works on 2D images and assumes |
| `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" |
|
|
| import torch |
|
|
| |
| |
| |
|
|
| |
|
|
| enabled = False |
|
|
| |
|
|
|
|
| def grid_sample(input, grid): |
| if _should_use_custom_op(): |
| return _GridSample2dForward.apply(input, grid) |
| return torch.nn.functional.grid_sample(input=input, |
| grid=grid, |
| mode='bilinear', |
| padding_mode='zeros', |
| align_corners=False) |
|
|
|
|
| |
|
|
|
|
| def _should_use_custom_op(): |
| return enabled |
|
|
|
|
| |
|
|
|
|
| class _GridSample2dForward(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, input, grid): |
| assert input.ndim == 4 |
| assert grid.ndim == 4 |
| output = torch.nn.functional.grid_sample(input=input, |
| grid=grid, |
| mode='bilinear', |
| padding_mode='zeros', |
| align_corners=False) |
| ctx.save_for_backward(input, grid) |
| return output |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| input, grid = ctx.saved_tensors |
| grad_input, grad_grid = _GridSample2dBackward.apply( |
| grad_output, input, grid) |
| return grad_input, grad_grid |
|
|
|
|
| |
|
|
|
|
| class _GridSample2dBackward(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, grad_output, input, grid): |
| op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') |
| grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) |
| ctx.save_for_backward(grid) |
| return grad_input, grad_grid |
|
|
| @staticmethod |
| def backward(ctx, grad2_grad_input, grad2_grad_grid): |
| _ = grad2_grad_grid |
| grid, = ctx.saved_tensors |
| grad2_grad_output = None |
| grad2_input = None |
| grad2_grid = None |
|
|
| if ctx.needs_input_grad[0]: |
| grad2_grad_output = _GridSample2dForward.apply( |
| grad2_grad_input, grid) |
|
|
| assert not ctx.needs_input_grad[2] |
| return grad2_grad_output, grad2_input, grad2_grid |
|
|
|
|
| |
|
|