861 flops += self.num_heads * N * (self.dim // self.num_heads) * N
862 # x = (attn @ v)
863 flops += self.num_heads * N * N * (self.dim // self.num_heads)
864 # x = self.proj(x) 865 flops += N * self.dim * self.dim
866 return flops
867
859 flops += N * self.dim * 3 * self.dim
860 # attn = (q @ k.transpose(-2, -1))
861 flops += self.num_heads * N * (self.dim // self.num_heads) * N
862 # x = (attn @ v) 863 flops += self.num_heads * N * N * (self.dim // self.num_heads)
864 # x = self.proj(x)
865 flops += N * self.dim * self.dim
857 flops = 0
858 # qkv = self.qkv(x)
859 flops += N * self.dim * 3 * self.dim
860 # attn = (q @ k.transpose(-2, -1)) 861 flops += self.num_heads * N * (self.dim // self.num_heads) * N
862 # x = (attn @ v)
863 flops += self.num_heads * N * N * (self.dim // self.num_heads)
855 def flops(self, N):
856 # calculate flops for 1 window with token length of N
857 flops = 0
858 # qkv = self.qkv(x) 859 flops += N * self.dim * 3 * self.dim
860 # attn = (q @ k.transpose(-2, -1))
861 flops += self.num_heads * N * (self.dim // self.num_heads) * N
678 def forward(self, x, x_size):
679 h, w = x_size
680 b, _, c = x.shape
681 # assert seq_len == h * w, "input feature has wrong size" 682
683 shortcut = x
684 x = self.norm1(x)
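Commented-out code should be removed from your codebase: it is never executed, it silently goes stale as the surrounding code changes, and any old version that is still needed can be recovered from version-control history. The first snippet below is noncompliant (it keeps commented-out code); the second shows the same loop with that code removed.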
# Noncompliant: the loop still carries a commented-out `print` call, and the
# block after it is an entire commented-out function whose own docstring says
# it has been replaced.
for item in sequence:
    # print(item)
    do_something(item)
# def old_function():
#     '''Older implementation that has been replaced'''
#     data = get_data()
#     api.post(data)
# Compliant: the same loop as the noncompliant snippet, with the commented-out
# call and the commented-out old function removed.
for item in sequence:
    do_something(item)