
I am trying to solve the following problem using PyTorch: given a six-sided die whose average roll is known to be 4.5, what is the maximum entropy distribution for the faces?

(Note: I know a bunch of non-PyTorch techniques for solving problems of this sort; my goal here is really to better understand how to solve constrained optimization problems in general with PyTorch. In real life I'm working on a much harder constrained optimization problem involving a neural model implemented in PyTorch, and I'm hoping that if I can solve this problem then it will help with the harder problem.)

In principle it should be possible to handle this by looking for critical points of the Lagrangian:

$$L(p) = -\sum_i p_i \log p_i + \lambda\left(\sum_i p_i - 1\right) + \mu\left(\sum_i i p_i - 4.5\right)$$
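Working out the stationarity conditions explicitly (a step left implicit above) shows what the answer has to look like. Differentiating $L$ with respect to $p_i$ gives

$$\frac{\partial L}{\partial p_i} = -\log p_i - 1 + \lambda + \mu i = 0 \quad\Longrightarrow\quad p_i = e^{\lambda - 1}\, e^{\mu i},$$

so $p_i \propto e^{\mu i}$, while $\partial L/\partial \lambda = 0$ and $\partial L/\partial \mu = 0$ simply restate the two constraints. Because $L$ is linear in $\lambda$ and $\mu$, this critical point is a saddle point of $L$, not a minimum. The quoted true answer has this form: consecutive ratios $p_{i+1}/p_i$ are all about 1.45. As a minimal pure-Python sanity check (no PyTorch; `max_entropy_die` is a hypothetical helper name), the single unknown $\mu$ can be found by bisection, which is valid because the mean of this family increases with $\mu$ (its derivative is the variance of the face value):

```python
import math


def max_entropy_die(mean, num_faces=6, tol=1e-12):
    """Max-entropy distribution on faces 1..num_faces with the given mean.

    Uses the stationarity condition p_i proportional to exp(mu * i) and
    solves for the single unknown mu by bisection.
    """
    if not 1 < mean < num_faces:
        raise ValueError("mean must lie strictly between 1 and num_faces")
    faces = range(1, num_faces + 1)

    def dist(mu):
        # Shift exponents by their max so math.exp cannot overflow.
        shift = max(mu * i for i in faces)
        weights = [math.exp(mu * i - shift) for i in faces]
        total = sum(weights)
        return [w / total for w in weights]

    def mean_of(mu):
        return sum(i * p for i, p in zip(faces, dist(mu)))

    lo, hi = -50.0, 50.0  # brackets any mean not extremely close to 1 or num_faces
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mean_of(mid) < mean:
            lo = mid
        else:
            hi = mid
    return dist(0.5 * (lo + hi))


print([round(p, 4) for p in max_entropy_die(4.5)])
print([round(p, 4) for p in max_entropy_die(3.5)])  # mu = 0 here, i.e. uniform
```

The first line should agree with the true answer quoted in the question, and the second should be the uniform distribution, which gives a fixed reference to test any PyTorch version against.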

Here's my attempt to do this with pytorch:

```python
import torch
import torch.nn.functional as F


class MaxEntropyDice(torch.nn.Module):
    def __init__(self, num_faces=6, mean_constraint=3.5):
        super().__init__()
        self.num_faces = num_faces
        self.mean_constraint = mean_constraint
        # Random starting distribution, normalized so it sums to 1
        self.p = torch.nn.Parameter(F.normalize(torch.rand(num_faces), p=1, dim=0))
        # Lagrange multipliers for sum(p) == 1 and E[face] == mean_constraint
        self.probability_multiplier = torch.nn.Parameter(torch.rand(1))
        self.mean_multiplier = torch.nn.Parameter(torch.rand(1))

    def forward(self):
        entropy = -torch.sum(self.p * torch.log(self.p))
        probability_term = self.probability_multiplier * (torch.sum(self.p) - 1)
        mean_term = self.mean_multiplier * (
            torch.sum(torch.tensor(range(1, self.num_faces + 1)) * self.p) - self.mean_constraint
        )
        # The Lagrangian L(p) from the formula above
        lagrangian = entropy + probability_term + mean_term
        return lagrangian


model = MaxEntropyDice(num_faces=6, mean_constraint=4.5)
# SGD minimizes the returned value jointly over p and both multipliers
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

for i in range(2000):
    loss = model()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

This results in the probability distribution [0.1759, 0.0827, 0.0457, 0.1483, 0.2648, 0.2583], which is not correct; the true answer is [0.05435, 0.07877, 0.1142, 0.1654, 0.2398, 0.3475]. (Also, if I set `mean_constraint=3.5`, I don't get the uniform distribution, which is a bad sign.)
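A few features of the loop above plausibly explain that output; these are observations, not a confirmed diagnosis. SGD *minimizes* `loss`, so adding `entropy` with a plus sign pushes entropy down rather than up. Treating both multipliers as ordinary parameters asks SGD for a minimum of $L$, but $L$ is linear in the multipliers and has no minimum there; stepping a multiplier against its gradient just drives `loss` lower whenever a constraint is violated. Nothing keeps `p` positive, so `torch.log(self.p)` could produce NaN if a component went negative. And with `lr=1e-6` for 2000 steps the parameters can only drift slightly from their random start; the reported vector sums to about 0.976, not 1.

Below is a minimal sketch of one common alternative, under stated assumptions: `max_entropy_die_torch` is a hypothetical name, and `rho`, `lr` and the step counts are untuned guesses. It swaps in two standard techniques in place of plain SGD on all of $L$: a softmax over free logits removes the normalization constraint (and its multiplier), and the mean constraint is handled with an augmented Lagrangian (method of multipliers), where SGD minimizes negative entropy $+\ \mu g + \tfrac{\rho}{2} g^2$ over the logits only and $\mu$ is updated by a separate rule between inner loops.

```python
import torch


def max_entropy_die_torch(mean, num_faces=6, rho=5.0, lr=0.1,
                          outer_steps=20, inner_steps=300):
    """Augmented-Lagrangian sketch: maximize entropy subject to E[face] == mean.

    Returns the distribution and the multiplier estimate for the mean constraint.
    """
    faces = torch.arange(1, num_faces + 1, dtype=torch.float64)
    # Free logits, started at zero (the uniform distribution). softmax keeps p
    # positive and summing to 1, so no normalization multiplier is needed.
    logits = torch.zeros(num_faces, dtype=torch.float64, requires_grad=True)
    opt = torch.optim.SGD([logits], lr=lr)
    mu = 0.0  # multiplier for the mean constraint; updated by hand, not by SGD

    for _ in range(outer_steps):
        # Inner loop: minimize neg_entropy + mu * g + (rho / 2) * g^2 over the
        # logits only, with the multiplier held fixed.
        for _ in range(inner_steps):
            log_p = torch.log_softmax(logits, dim=0)
            p = log_p.exp()
            neg_entropy = torch.sum(p * log_p)
            g = torch.dot(faces, p) - mean  # mean-constraint violation
            loss = neg_entropy + mu * g + 0.5 * rho * g ** 2
            opt.zero_grad()
            loss.backward()
            opt.step()
        # Outer step (method of multipliers): shift mu by rho times the
        # violation that remains after the inner minimization.
        with torch.no_grad():
            g = torch.dot(faces, torch.softmax(logits, dim=0)) - mean
        mu += rho * g.item()

    return torch.softmax(logits, dim=0).detach(), mu


p, mu = max_entropy_die_torch(4.5)
print(p)                               # compare with the true answer quoted above
print(max_entropy_die_torch(3.5)[0])   # the mean-3.5 case should stay uniform
```

Here $\mu$ enters the minimized objective with the opposite sign to the question's $L$, so at convergence $p_i \propto e^{-\mu i}$ and consecutive ratios should be close to `exp(-mu)`. Both calls start from the uniform distribution, so the mean-3.5 case stays there trivially; the 4.5 case is the real check.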

Any ideas on how I can make this work?

Paul Siegel