AT274の技術だったりなかったり

あっとのTECH LOG

競プロのこととか技術のこととか。たまに日常を書きます。

ABC150 E - Change a Little Bit

問題原文

atcoder.jp

問題要旨

長さ  N の、 0と1からなる相異なる数列  S,  T を考える。また、長さ  N の数列  C が与えられる。  S に対して以下の操作を繰り返し行い、  S = T にすることを考える。

  •  S のある項を0なら1に、1なら0に変える。この時コストとして、  C_i × (その時点で  S_i != T_i となっているようなものの数)がかかる。

上記に操作を繰り返し行うことで、 [S = T] にするまでにかかる最小のコストを  f(S, T) と定める。
考えうる全ての  S,  T の組み合わせについて、  f(S, T) を計算し、その総和を  10^{9} + 7 で割ったあまりを求めよ。

  •   1 \leq N \leq 2 × 10^{5}
  •   1 \leq C \leq 10^{9}

解法

 T を 000...000に固定する

まず、  S,  T は相異なると問題文にあるが、  S = T となるようなものはそもそもコストが0なのでまとめて考えてしまって良い。
また結局全パターンについて試されるので、  T = 000...000 として、考えうる全ての  T について答えを求め、 最後に  2^{N} 倍すればいい。

どんな数から  T に揃えていくか

ここから、 各   S についてのみ考える。また、  T = 000...000 としているので、  S_i 番目が1であることは、その部分が  T と異なることを表す。 これを踏まえた上で、 「どんな数から  T に揃えていくか」を考える。
まず、 すでに  S_i = T_i となっているようなものについては明らかにそのままにしておくのがいい。
またよく考えると、  C_i が小さいようなものから揃えていくのが最適であることがわかる。

  S についてのコストを数える → 各数についてどんな係数がつくのかを考える

次に各  S についてのコストを考えていく。
愚直に考えれば、  (C_i × 4) + (C_j × 3) + ... みたいな式になるが、発想を変えて、『各数についてどんな係数がつくのかを考える』 。
つまり 答えを、  \sum_{i=1}^{N} C_i × (全ての  S のパターンについて、 S_i = T_i とする時に S_i != T_i となっているようなものの数の総和) とする。 また結局右の係数は、  C がソート済みであるとするならば、「全ての  S について、  S_i より右側( i 番目含む)に存在する 0 の数」と言い換えられる。

実際の係数のカウント

次は、係数をいかにして数えるかが問題になる。
 S = zzzzz1xxxxxx(6番目が異なる場合) だとして1の右側に何個1が並ぶだろうか?
組み合わせ計算に走りたくなるが、それだと破滅する。ここでも、各数ごとに注目し、「各桁が何回1になりうるか?」を考える。(わかりやすいのでもう桁っていっちゃう)

 zzz...zzz の部分

ここは1より左側にあるから、コストには影響しない。

 1 の部分

1は固定。また、ある箇所を1に固定する場合、とりうるパターン数は  2^{(N - 1)} 個。
だが考えやすいので、  z の数を  l,  x の数を  r とすると、  2^{(l + r)} 個と言える。

 xxx...xxx の部分

 x について、 1になるパターンと0にパターンは実は半分半分になる。
なので、  x の数を  r とすると、 各  x が 1になるパターンは、  \frac{2^{r}}{2} = 2^{r - 1} 個になる。
またそうした  x r 個あるので、 xxx...xxx に現れる1の数は、  2^{r - 1} × r 個になる。
また、  zzz...zzz の部分は自由に決められるので、 全パターンついてでは結局、  2^{l} × (2^{r - 1} × r) 個になる。

答えを求める

以上からある項についてのコストの係数がわかり、ある項についてのコストがわかる。(=答えが求められる。)
具体的には、  2^{(l + r)} と、  2^{l} × (2^{r - 1} × r) の和になるが、これを  2^{l} でくくりだし、  2^{l} × (2^{r - 1} × r + 2^{r}) が各係数になる。

実装

 2^{i} を前計算しててもいいけど、毎回やっても大丈夫。

N = int(input())
C = [0] + sorted(list(map(int, input().split())))
MOD = 10 ** 9 + 7

ans = 0
for i in range(1, N + 1):
    l, r = i - 1, N - i
    ans += C[i] * pow(2, l, MOD) * (pow(2, max(0, r - 1), MOD) * r % MOD + pow(2, r, MOD)) % MOD

print(ans * pow(2, N, MOD) % MOD)

感想

むずかしい、、、けど詰められたはず 係数計算部分でコンビネーションに走ったのが失敗でした。。。

各桁ごとに見る を使ったあとでも意識。。。!!!