Tävlingsprogrammering/Uppgifter/Tärningssumma

Från Wikibooks


Tärningssumma

Tärningen är en klassisk symbol för slump. Slumpen är i teorin rättvis, men det gäller tyvärr inte i praktiken. I synnerhet gäller det inte i denna uppgift, där vi jobbar med orättvisa tärningar.

Du kastar ett antal (1 ≤ N ≤ 100) identiska sexsidiga tärningar exakt en gång var. Du känner till sannolikheten p(i), i = 1, 2, 3, 4, 5, 6 för att med en tärning få etta, tvåa, trea o.s.v. (du kan anta att sannolikheterna adderar upp till exakt ett).

Skriv ett program som beräknar sannolikheten för att summan av tärningarnas utfall blir större än eller lika med ett visst heltalsvärde S, där 1 ≤ S ≤ 600. Svaret ska vara korrekt med minst 6 decimalers noggrannhet.

Delpoäng kommer ges om programmet klarar alla testfall med N ≤ 10.


Körningsexempel 1

Antal tärningar ? 2
Talet S ? 6
p(1) ? 0.2
p(2) ? 0.2
p(3) ? 0.2
p(4) ? 0.2
p(5) ? 0.1
p(6) ? 0.1
Svar: 0.6

Körningsexempel 2

Antal tärningar ? 100
Talet S ? 353
p(1) ? 0.1
p(2) ? 0.2
p(3) ? 0.15
p(4) ? 0.3
p(5) ? 0.15
p(6) ? 0.1
Svar: 0.432678

Lösning[redigera]

Antag att vi har en funktion f(x,y) som returnerar sannolikheten att summan blir minst S då vi kastat x tärningar och nuvarande summa är y. Vi kan då uttrycka f(x,y) rekursivt på följande sätt:

f(kastade, summa) = 
    if kastade == N:
        returnera 1.0 om summa >= S, annars 0.0.
    annars:
        double svar = 0.0
        för 1 <= i <= 6:
            svar += p[i] * f(kastade+1,summa+i)
        return svar

En naiv implementation av algoritmen räcker till delpoängen, men för att klara större indata så använder vi memoisering, dvs vi sparar undan värdet vi beräknat för ett visst x,y i en tabell. Om funktionen anropas igen med samma parametrar så behöver vi inte räkna om svaret, och tidskomplexitet för algoritmen förvandlas från O(6^(antal tärningar)) till O(antal tärningar * maximal summa).

Implementation i C++:

#include <iostream>

#define MAX_DICES 105
#define MAX_SUM 605

using namespace std;

double dp[MAX_DICES][MAX_SUM];
double probs[6];
int N, S;

double getprob(int dice, int sum) {
    if (dice == N) {
        if (sum >= S) return 1.0;
        else return 0.0;
    }
    if (dp[dice][sum] > -0.5) return dp[dice][sum]; // utnyttja att tabellvardet alltid ar mellan 0 och 1 nar vi beraknat det.
    double ans = 0;
    for (int i = 0; i < 6; ++i) {
        ans += probs[i]*getprob(dice+1,sum + i + 1);
    }
    return dp[dice][sum]=ans; // spara undan tabellvardet
}

int main() {
    cin >> N >> S;
    for (int i = 0; i < 6; ++i) cin >> probs[i];
    for (int i = 0; i < MAX_DICES; ++i) for (int j = 0; j < MAX_SUM; ++j) dp[i][j] = -1.0; // -1 i dp-tabellen betyder icke-beraknat varde
    cout << getprob(0,0) << endl;
}