AT274の技術だったりなかったり

あっとのTECH LOG

競プロのこととか技術のこととか。たまに日常を書きます。

ABC154 E - Almost Everywhere Zero

問題原文

atcoder.jp

問題要旨

1以上  N 以下の整数であって、10進法で表した時に、0でない数字がちょうど  K 個あるようなものの個数を求めよ。

  •   1 \leq N \leq 10^{100}
  •   1 \leq K \leq 3

解法1:がんばる

 K が高々3通りしか無い and 制約が小さい ので、頑張ればできます。 実装を見た方がいいと思うのでそちらで。

実装

数をつくるところで下手に mapとかを使うと計算量増えてTLEするので気をつけましょう。

N = int(input())
N_digit = len(str(N))
K = int(input())
ans = 0
factorial = [1, 1]
for i in range(2, 101):
    factorial.append(factorial[-1] * i)
 
def nCr(n, r):
    if n - r < 0:
        return 0
    return factorial[n] // (factorial[r] * factorial[n - r])
 
if K == 1:
    ans += 9 * (nCr(N_digit - 1, 1))  # 1桁目を0に固定、2桁目以降は自由
    ans += int(str(N)[0])  # 1桁目を考える。2桁目以降は0
 
elif K == 2:
    ans += (9 ** 2) * (nCr(N_digit - 1, 2))  # 1桁目を0に固定、2桁目以降から自由に2つ
    ans += (int(str(N)[0]) - 1) * 9 * (nCr(N_digit - 1, 1))  # 1桁目を0でなく、かつN以下が確定するように固定
    for d2_position in range(1, N_digit):  # 2つめの0でない数をどこに置くか
        for d2_value in range(1, 10):  # 2つめの0でない数を何にするか
            # 数をつくる
            S = ['0'] * N_digit
            S[0] = str(N)[0]
            S[d2_position] = d2_value
            S = int(''.join(map(str, S)))
 
            ans += (S <= N)
 
elif K == 3:
    ans += (9 ** 3) * (nCr(N_digit - 1, 3))  # 1桁目を0に固定、2桁目以降から自由に3つ
    ans += (int(str(N)[0]) - 1) * (9 ** 2) * (nCr(N_digit - 1, 2))  # 1桁目を0でなく、かつN以下が確定するように固定
    for d2_position in range(1, N_digit):  # 2つめの0でない数をどこに置くか
        for d3_position in range(d2_position + 1, N_digit):  # 3つめの0でない数をどこに置くか
            for d2_value in range(1, 10):  # 2つめの0でない数を何にするか
                for d3_value in range(1, 10):  # 3つめの0でない数を何にするか
                    # 数を作る
                    S = ['0'] * N_digit
                    S[0] = str(N)[0]
                    S[d2_position] = str(d2_value)
                    S[d3_position] = str(d3_value)
                    S = int(''.join(S))
 
                    ans += (S <= N)
 
print(ans)

解法2:桁dp

「どこまで見たか?」「0でない数を何個使ったか?」「 N 以下であることがすでに確定しているか?」の3つの情報を持って桁dpができる。
これも実装を見た方がよさそうと思うのでそちらで。。。

実装

the ベタ書き桁dpという感じ。
解説放送とかだともっと楽にやってそうだけど、dp苦手なのでこういうのしかかけない。。。 3次元にしてもいいのだけど、 N 以下であることが確定しているかどうかでdpテーブルを分けた方が個人的には楽ですね。

N = input()
N_digit = len(N)
K = int(input())

# dp[i][k] := i番目みて、0でない数がk個出現しているときの通り数
# dp0 := N以下であることが未確定, dp1 := N以下であることが確定
dp0 = [[0] * (K + 1) for i in range(N_digit + 1)]
dp1 = [[0] * (K + 1) for i in range(N_digit + 1)]
dp0[0][0] = 1

for i, n in enumerate(N):
    for k in range(K + 1):
        if n == '0':
            dp0[i + 1][k] += dp0[i][k]
            dp1[i + 1][k] += dp1[i][k]
            if k < K:
                dp1[i + 1][k + 1] += 9 * dp1[i][k]

        else:
            if k < K:
                dp0[i + 1][k + 1] += dp0[i][k]  # N以下が確定しないようにするので、nしか選べない
                dp1[i + 1][k + 1] += (int(n) - 1) * dp0[i][k]  # N以下を確定させて0で無いものの使用数を増やす
                dp1[i + 1][k + 1] += 9 * dp1[i][k]  # N以下が確定しているうちから、次に1 ~ 9の9通りの選択

            dp1[i + 1][k] += dp0[i][k]  # N以下未確定状態から0を使う
            dp1[i + 1][k] += dp1[i][k]  # N以下確定状態から0を使う

print(dp0[-1][K] + dp1[-1][K])

感想

1つ目の解法は結構好き。桁dpは苦手。
いや、両方パパッとかけないとダメなんですけどね。。。

数を実際に構築する際にカッコつけてmapとか使うと計算量増えることを忘れてました。(きをつけよう)