I think this is done because after each composite function we receive k new features that are appended to the main stack. If we perform BN first, the whole set of input features is balanced. Otherwise, only the layer's output (the k new features) is balanced.
So let's imagine the input array. In the first case we will get:
[k-unbalanced, k-unbalanced, k-unbalanced, ..], but balanced overall once BN is applied to the whole input.
And in the second case:
[k-balanced, k-balanced, k-balanced, ..], but not balanced overall.
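
To make the two cases concrete, here is a rough NumPy sketch of how I read it, not anything taken from the paper or a real implementation. It assumes "balanced" means zero mean and unit variance per channel, that features have shape (batch, channels), and that in the second case each layer's BN has its own learned scale and shift. The block statistics and the gamma/beta values are made up purely for illustration.

```python
import numpy as np

def batch_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Per-channel batch norm over the batch axis; x has shape (batch, channels)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
batch, k = 256, 4

# Hypothetical raw outputs of three layers (k features each),
# with deliberately different offsets and scales.
raw_blocks = [
    rng.normal(loc=m, scale=s, size=(batch, k))
    for m, s in [(0.0, 1.0), (5.0, 10.0), (-3.0, 0.1)]
]

# Case 1 (BN first): the stack holds the raw blocks, and the next layer
# normalizes the whole concatenated input before using it.
stack_raw = np.concatenate(raw_blocks, axis=1)
case1 = batch_norm(stack_raw)

# Case 2 (BN last): each layer normalizes only its own k outputs, with its
# own (made-up) gamma and beta, and then appends them to the stack.
affine = [(1.0, 0.0), (3.0, 2.0), (0.2, -1.0)]
case2 = np.concatenate(
    [batch_norm(block, g, b) for block, (g, b) in zip(raw_blocks, affine)],
    axis=1,
)

case1_std = case1.std(axis=0)
case2_std = case2.std(axis=0)
print("case 1 per-channel std:", np.round(case1_std, 2))
print("case 2 per-channel std:", np.round(case2_std, 2))
```

In this toy setup every channel in the first case ends up on the same scale, while in the second case each block is consistent internally but the blocks differ from one another. If the per-layer scale and shift were all left at 1 and 0, the two cases would look the same here, so the difference depends on that assumption.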