I am trying to solve the following problem using pytorch: given a six-sided die whose average roll is known to be 4.5, what is the maximum-entropy distribution over the faces?
(Note: I know a bunch of non-pytorch techniques for solving problems of this sort; my goal here is really to better understand how to solve constrained optimization problems in general with pytorch. In real life I'm working on a much harder constrained optimization problem involving a neural model implemented in pytorch, and I'm hoping that solving this problem will help with the harder one.)
In principle it should be possible to handle this by looking for critical points of the Lagrangian, where $i$ runs over the faces $1, \dots, 6$ and $\lambda$ and $\mu$ are Lagrange multipliers:
$$L(p) = -\sum_i p_i \log p_i + \lambda\left(\sum_i p_i - 1\right) + \mu\left(\sum_i i p_i - 4.5\right)$$
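Setting the partial derivatives of $L$ to zero gives the conditions any critical point must satisfy. This is a sketch of the standard derivation, assuming every $p_i > 0$ so that $\log p_i$ is defined:

$$
\begin{aligned}
\frac{\partial L}{\partial p_i} &= -\log p_i - 1 + \lambda + \mu i = 0
  &&\Longrightarrow\quad p_i = e^{\lambda - 1 + \mu i}, \\
\frac{\partial L}{\partial \lambda} &= \sum_i p_i - 1 = 0
  &&\Longrightarrow\quad p_i = \frac{e^{\mu i}}{\sum_{j=1}^{6} e^{\mu j}}, \\
\frac{\partial L}{\partial \mu} &= \sum_i i\,p_i - 4.5 = 0
  &&\Longrightarrow\quad \frac{\sum_{i=1}^{6} i\,e^{\mu i}}{\sum_{j=1}^{6} e^{\mu j}} = 4.5 .
\end{aligned}
$$

So a critical point has the form $p_i \propto e^{\mu i}$, with the single unknown $\mu$ fixed by the mean constraint; $\mu = 0$ gives the uniform distribution, whose mean is 3.5.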
Here's my attempt to do this with pytorch:
```python
import torch
import torch.nn.functional as F


class MaxEntropyDice(torch.nn.Module):
    def __init__(self, num_faces=6, mean_constraint=3.5):
        super().__init__()
        self.num_faces = num_faces
        self.mean_constraint = mean_constraint
        # Face probabilities, initialized to a random vector normalized to sum to 1.
        self.p = torch.nn.Parameter(F.normalize(torch.rand(num_faces), p=1, dim=0))
        # Lagrange multipliers for the sum-to-one and mean constraints.
        self.probability_multiplier = torch.nn.Parameter(torch.rand(1))
        self.mean_multiplier = torch.nn.Parameter(torch.rand(1))

    def forward(self):
        entropy = -torch.sum(self.p * torch.log(self.p))
        probability_term = self.probability_multiplier * (torch.sum(self.p) - 1)
        faces = torch.arange(1, self.num_faces + 1)  # face values 1..num_faces
        mean_term = self.mean_multiplier * (
            torch.sum(faces * self.p) - self.mean_constraint
        )
        lagrangian = entropy + probability_term + mean_term
        return lagrangian


model = MaxEntropyDice(num_faces=6, mean_constraint=4.5)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for i in range(2000):
    loss = model()  # the Lagrangian is used directly as the loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(model.p.detach())
```
This results in the probability distribution [0.1759, 0.0827, 0.0457, 0.1483, 0.2648, 0.2583], which is not correct: the true answer is [0.05435, 0.07877, 0.1142, 0.1654, 0.2398, 0.3475].
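One quick diagnostic, sketched here in plain Python rather than pytorch, is to plug both distributions back into the two constraints from the Lagrangian. The helper `constraint_residuals` is a hypothetical name introduced only for this check.

```python
# Check how well a distribution over faces 1..n satisfies the two constraints
# in the Lagrangian: sum(p) == 1 and sum(i * p_i) == target_mean.

def constraint_residuals(p, target_mean=4.5):
    """Return (sum(p) - 1, sum(i * p_i) - target_mean) for faces i = 1..len(p)."""
    total = sum(p)
    mean = sum(i * p_i for i, p_i in enumerate(p, start=1))
    return total - 1.0, mean - target_mean


sgd_result = [0.1759, 0.0827, 0.0457, 0.1483, 0.2648, 0.2583]
true_answer = [0.05435, 0.07877, 0.1142, 0.1654, 0.2398, 0.3475]

for name, p in [("SGD result", sgd_result), ("true answer", true_answer)]:
    sum_residual, mean_residual = constraint_residuals(p)
    print(f"{name}: sum - 1 = {sum_residual:+.4f}, mean - 4.5 = {mean_residual:+.4f}")
```

For the SGD result the probabilities sum to 0.9757 and the mean is 3.9454, so neither constraint holds, whereas the true answer satisfies both to within rounding.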
(Also, if I set `mean_constraint=3.5`, I don't get the uniform distribution, which is a bad sign.)
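For reference, the closed form $p_i \propto e^{\mu i}$ from the stationarity conditions can be solved for $\mu$ by bisection, since the mean is strictly increasing in $\mu$ (its derivative is the variance of the face value). This is a plain-Python sketch of that non-pytorch baseline, not a pytorch method; `max_entropy_die` is a hypothetical helper name, and the bracket $[-50, 50]$ for $\mu$ is an assumption that holds for target means strictly between 1 and 6.

```python
import math


def max_entropy_die(target_mean, num_faces=6, tol=1e-12):
    """Max-entropy distribution on faces 1..num_faces with the given mean.

    Assumes the stationary form p_i = exp(mu * i) / Z and finds mu by bisection.
    """
    faces = range(1, num_faces + 1)

    def dist(mu):
        # Subtract the largest exponent before exponentiating, for numerical stability.
        shift = max(mu * i for i in faces)
        weights = [math.exp(mu * i - shift) for i in faces]
        z = sum(weights)
        return [w / z for w in weights]

    def mean(mu):
        return sum(i * p for i, p in zip(faces, dist(mu)))

    lo, hi = -50.0, 50.0  # assumed bracket for mu
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mean(mid) < target_mean:
            lo = mid
        else:
            hi = mid
    return dist(0.5 * (lo + hi))


print([round(p, 4) for p in max_entropy_die(4.5)])
print([round(p, 4) for p in max_entropy_die(3.5)])
```

With a target mean of 3.5 this returns the uniform distribution, and with 4.5 it reproduces the true answer quoted above, which makes it a handy ground truth to compare any pytorch attempt against.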
Any ideas on how I can make this work?