Prevent uncessary bias backup

This commit is contained in:
huchenlei 2024-05-16 11:39:01 -04:00
parent ddb28b33a3
commit 51b13a8c54

View File

@ -378,7 +378,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None) bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None: if bias_backup is None and wanted_names != ():
if current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True) bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
elif getattr(self, 'bias', None) is not None: elif getattr(self, 'bias', None) is not None: