前言
截止目前(torch2.5)似乎是一个不支持的功能。
样例
class MyClass(object):def __init__(self, w1,w2,w3):self.regist_buffer('w1', w1)self.regist_buffer('w2', w2)self.regist_buffer('w3', w3)def forward(self,x, i):o = 0for i in range(1,4):w_name = f'w{i}'w = self.get_buffer(w_name)o += w*xreturn omodel = MyClass()
script_model = torch.jit.script(model)
这样会有问题。
在script之后,调用 self.get_buffer()
报错, Unknown type name 'torch.nn.Module'
法2
def forward(self,x, i):o = 0for i in range(1,4):w_name = f'w{i}'w = self._buffers[w_name]# orw = self.__dict__['_buffers'][w_name]o += w*xreturn ow_name = f'w{i}'w = self._buffers[w_name]# orw = self.__dict__['_buffers'][w_name]return w*x
script之后报错
Module 'MyClass' has no attribute '_buffers'
Module 'MyClass' has no attribute '__dict__
不存在self._buffers
和 __dict__
法3
def forward(self,x):o = 0for i in range(1,4):w_name = f'w{i}'w = getattr(self, w_name)o += w*xreturn o
script之后报错
getattr's second argument must be a string literal
getattr只支持静态字面量。
但有时候我们是希望动态获取的。
PS:
torch.compile 暂时也不支持 cuda.stream 相关的操作。