一、前向传播（Forward）
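$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot w + b$$

在上式中，x 代表输入的一个 tensor（张量），E[x] 与 Var[x] 分别为 x 每一行（共 N 列）上的均值与方差，w 和 b 为权重与偏置，ε 是用于避免除零的小常数。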
import torchimport triton
import triton.language as tltry:# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it# should not be added to extras_require in setup.py.import apexHAS_APEX = True
except ModuleNotFoundError:HAS_APEX = False@triton.jit
def _layer_norm_fwd_fused(X, # pointer to the inputY, # pointer to the outputW, # pointer to the weightsB, # pointer to the biasesMean, # pointer to the meanRstd, # pointer to the 1/stdstride, # how much to increase the pointer when moving by 1 rowN, # number of columns in Xeps, # epsilon to avoid division by zeroBLOCK_SIZE: tl.constexpr,
):# Map the program id to the row of X and Y it should compute.row = tl.