from functools import reduce
import operator
import torch
import torch.nn.functional as F
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import standard_kwargs, \
    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
from typing import List, Optional

@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
    """Autograd function registered as the per-sample-grad rule for ``F.group_norm``.

    ``forward`` evaluates ``torch.native_group_norm`` (via ``forward_helper``)
    and stashes what ``backward`` needs on ``ctx``. ``backward`` returns a
    regular gradient for ``input`` only; gradients for ``weight`` and ``bias``
    are not returned but handed to ``set_grad_sample_if_exists`` as closures
    that compute one gradient row per sample.
    """

    @staticmethod
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        """Run group norm and save state for ``backward``.

        Args:
            ctx: autograd context used to carry values to ``backward``.
            kwarg_names: names of the trailing keyword values packed into
                ``expanded_args_and_kwargs``; must include ``'weight'``,
                ``'bias'`` and ``'eps'``.
            _: unused (the op reference).
            *expanded_args_and_kwargs: ``input`` and ``num_groups`` followed by
                the keyword values, split apart by ``standard_kwargs``.

        Returns:
            The first element (the normalized output) of the
            ``torch.native_group_norm`` result.
        """
        expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs)
        input, num_groups = expanded_args
        N = input.shape[0]
        C = input.shape[1]
        # Product of every dimension after (N, C); 1 when there are none.
        HxW = reduce(operator.mul, input.shape[2:], 1)
        weight, bias, eps = expanded_kwargs['weight'], expanded_kwargs['bias'], expanded_kwargs['eps']
        output, mean, rstd = forward_helper(torch.native_group_norm, (input, weight, bias, N, C, HxW, num_groups, eps), {})

        ctx.input, ctx.num_groups = input, num_groups
        # weight is always stored: backward uses it for the input gradient and
        # passes it to set_grad_sample_if_exists.
        ctx.weight, ctx.eps = weight, eps
        ctx.mean, ctx.rstd = mean, rstd
        # bias is only stored when it can receive a per-sample gradient;
        # backward must therefore treat ctx.bias as optional.
        if isinstance(bias, ExpandedWeight):
            ctx.bias = bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the input gradient and register per-sample parameter grads.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss with respect to the output.

        Returns:
            A 7-tuple matching ``forward``'s inputs
            ``(kwarg_names, op, input, num_groups, weight, bias, eps)``. Only
            the ``input`` slot may hold a tensor (when ``input.requires_grad``);
            every other slot is ``None``.
        """
        input, num_groups = ctx.input, ctx.num_groups
        weight, eps = ctx.weight, ctx.eps
        # forward sets ctx.bias only for ExpandedWeight biases, so reading it
        # unconditionally would raise AttributeError for bias=None or a plain
        # tensor bias.
        bias = getattr(ctx, "bias", None)
        mean, rstd = ctx.mean, ctx.rstd

        grad_input: Optional[torch.Tensor] = None
        if input.requires_grad:
            weight_c = unpack_expanded_weight_or_tensor(weight, lambda t: t.contiguous())
            input_c = input.contiguous()
            grad_output_c = grad_output.contiguous() if grad_output is not None else None
            N = input.shape[0]
            C = input.shape[1]
            HxW = reduce(operator.mul, input.shape[2:], 1)
            bw_fn = torch.ops.aten.native_group_norm_backward
            # The (True, False, False) mask requests only the first result,
            # the gradient with respect to the input.
            grad_input = bw_fn(grad_output_c, input_c, mean, rstd, weight_c,
                               N, C, HxW, num_groups, (True, False, False))[0]

        # set grad_sample field for weight and bias with per sample gradients.
        # "ni...->ni" sums over every dimension after the first two, leaving
        # one row per sample and one column per channel.
        set_grad_sample_if_exists(
            weight,
            lambda _: torch.einsum("ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output))
        if bias is not None:
            set_grad_sample_if_exists(bias, lambda _: torch.einsum("ni...->ni", grad_output))

        # Slots: kwarg names, op reference, input, then num_groups / weight /
        # bias / eps. weight and bias don't compute batched gradients and no
        # other argument is differentiable.
        results: List[Optional[torch.Tensor]] = [None, None, grad_input, None, None, None, None]
        return tuple(results)