Add missing .mean

This commit is contained in:
KohakuBlueleaf 2024-03-13 11:47:33 +08:00
parent 3e0146f9bd
commit 9f2ae1cb85

View File

@ -173,7 +173,7 @@ class NetworkModule:
orig_weight = orig_weight.to(updown)
merged_scale1 = updown + orig_weight
dora_merged = (
merged_scale1 / merged_scale1(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
merged_scale1 / merged_scale1.mean(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
)
final_updown = dora_merged - orig_weight
return final_updown