
You can consider me a novice to intermediate at best with machine learning.

For the past few months, I've been developing a neural network in Keras (with a TensorFlow backend) that learns to play a 3D fighting game by trying to mimic how I play. The input data consists of a low-resolution, greyscale version of the frame at that moment, along with some corresponding categorical information. I extract important information from the game using computer vision and represent it as a multi-hot array, e.g. [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], where each index represents some piece of information about that moment, such as whether the enemy is attacking. This might all be useless information, but I'm providing it just in case.
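To make the encoding concrete, here is a minimal sketch of how such a multi-hot state vector is built (the feature names here are simplified placeholders, not the ones I actually extract):

```python
import numpy as np

# Hypothetical game-state features; the real list differs.
FEATURES = ["enemy_attacking", "enemy_guard_high", "enemy_guard_low",
            "player_grounded", "player_stunned"]

def encode_state(active):
    """Return a multi-hot vector with a 1 at each active feature's index."""
    vec = np.zeros(len(FEATURES), dtype=np.float32)
    for name in active:
        vec[FEATURES.index(name)] = 1.0
    return vec

state = encode_state(["enemy_attacking", "player_grounded"])
# state is [1., 0., 0., 1., 0.]
```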

The labels are the crucial matter. I am providing the network with an array of the buttons that I am pressing at that moment in time. The key thing to note is that this array is multi-hot, as I am often pressing multiple buttons at once.

The unfortunate problem I've had from the start is that my network just doesn't seem to work well. It claims to achieve over 90% accuracy and validation accuracy, and yet exhibits only a vague semblance of intelligence when tested. I think this is down to the inherent imbalance of my training labels, as some buttons are far, far more likely to be pressed than others. But I'm really not sure how to balance them, since traditional methods like over- and undersampling, or class weightings, don't seem to work with multi-label classification. I'm basically stuck, and googling isn't really helping.
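For context, the closest workaround I've come across is weighting the positive term of the binary cross-entropy separately for each label, roughly like this (a pure-NumPy sketch; the button frequencies are invented, only three of the 14 buttons are shown, and it would need rewriting with TensorFlow ops to serve as an actual Keras loss):

```python
import numpy as np

# Invented press frequencies for three example buttons (real setup has 14).
pos_freq = np.array([0.40, 0.10, 0.01])
pos_weight = (1.0 - pos_freq) / pos_freq  # rare positives get large weights

def weighted_bce(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy with per-label weights on the positive term."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_label = -(pos_weight * y_true * np.log(y_pred)
                  + (1 - y_true) * np.log(1 - y_pred))
    return per_label.mean()
```

With this loss, missing a press of the rare third button costs far more than missing a press of the common first one, which is the behaviour plain binary cross-entropy lacks.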

Any help or advice would be greatly appreciated. If you need more information, please don't hesitate to ask. Thank you!

EDIT: Additional information - The network architecture consists of two branches: one LSTM layer for the categorical data, and one convolutional layer for the image data. These are then concatenated and fed through two Dense layers before finally going through a sigmoid activation. I am using binary cross-entropy as the loss function; this is really the only combination of activation and loss that I am aware of for multi-label classification. In terms of metrics, both accuracy and binary accuracy come out above 90% for both training and validation every time, right from epoch 1.
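A minimal Keras sketch of this architecture (all sizes here, such as the image resolution, sequence length, and layer widths, are placeholders rather than my exact values):

```python
from tensorflow.keras import layers, Model

# Image branch: low-resolution greyscale frame through a conv layer.
img_in = layers.Input(shape=(64, 64, 1), name="frame")
x = layers.Conv2D(16, 3, activation="relu")(img_in)
x = layers.MaxPooling2D()(x)
x = layers.Flatten()(x)

# Categorical branch: a sequence of multi-hot game-state vectors into an LSTM.
cat_in = layers.Input(shape=(10, 10), name="game_state")
y = layers.LSTM(32)(cat_in)

# Concatenate, two Dense layers, then a sigmoid per button.
z = layers.concatenate([x, y])
z = layers.Dense(64, activation="relu")(z)
out = layers.Dense(14, activation="sigmoid")(z)  # 14 buttons, multi-hot

model = Model([img_in, cat_in], out)
model.compile(optimizer="adam", loss="binary_crossentropy",
              metrics=["binary_accuracy"])
```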

Here is a diagram of the model:

[model architecture diagram]

EDIT 2: After doing some thinking, might one potential way around this be to split everything into two neural networks: one for approximately half of the classes (the over-represented ones) and one for the under-represented ones, then run them simultaneously after training? Might this be a solution? Also, if anyone could suggest a better metric than binary accuracy, that would be appreciated. Thanks again.
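To illustrate why binary accuracy looks so optimistic for me, here is a NumPy sketch with toy data: a model that only ever predicts the most common button still scores high binary accuracy, while per-label F1 exposes the failure on the rare buttons.

```python
import numpy as np

def per_label_f1(y_true, y_pred):
    """F1 score computed separately for each label (button)."""
    tp = ((y_true == 1) & (y_pred == 1)).sum(axis=0)
    fp = ((y_true == 0) & (y_pred == 1)).sum(axis=0)
    fn = ((y_true == 1) & (y_pred == 0)).sum(axis=0)
    precision = tp / np.maximum(tp + fp, 1)
    recall = tp / np.maximum(tp + fn, 1)
    return np.where(precision + recall > 0,
                    2 * precision * recall / np.maximum(precision + recall, 1e-12),
                    0.0)

# Toy data: the model always predicts the common first button, never the rare ones.
y_true = np.array([[1, 0, 0], [1, 1, 0], [1, 0, 0], [0, 0, 1]])
y_pred = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 0]])

binary_accuracy = (y_true == y_pred).mean()  # ~0.83, looks fine
f1 = per_label_f1(y_true, y_pred)            # [1., 0., 0.]: rare buttons fail
macro_f1 = f1.mean()                         # ~0.33, reveals the problem
```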

Polyrogue
  • Welcome to Cross Validated! It would help if you added information about your network architecture and variations that you've tried, like choice of activation functions and loss functions. Also, did you look at any other metrics besides accuracy? – AlexK Apr 09 '19 at 23:48
  • Thank you for the comment, and the welcome! Of course, I'll edit the post now with all the information you requested. EDIT: Done. – Polyrogue Apr 10 '19 at 00:24
  • Don't use accuracy as a criterion, and don't oversample: [Are unbalanced datasets problematic, and (how) does oversampling (purport to) help?](https://stats.stackexchange.com/q/357466/1352) and links therein. – Stephan Kolassa Apr 10 '19 at 06:40
  • This site discusses and implements some methods to deal with this issue (including code for other kinds of loss functions): https://github.com/Bupenieks/ImbalancedMLC – AlexK Apr 10 '19 at 07:14
  • Make sure to check how they tackled a similar problem in AlphaStar https://deepmind.com/blog/alphastar-mastering-real-time-strategy-game-starcraft-ii/ – Jan Kukacka Apr 10 '19 at 08:14
  • Thank you, @AlexK and Jan Kukacka! I'll have a good look at both of the links you provide. – Polyrogue Apr 10 '19 at 08:23
  • You didn't mention what is the input of your network. Is it expected button id? Is it some kind of other action? What is imbalanced? Input? Target? – Jakub Bartczuk Apr 10 '19 at 09:37
  • @JakubBartczuk I did mention the input data, in detail, in the first paragraph. I explain that it is the image data of the frame going into a conv layer, and information about the game (like whether the enemy is attacking, or the direction of their guard) going into the LSTM layer. I also discuss in detail that it is the labels that are imbalanced. The labels are the 14 buttons that can be pressed on a controller (8 of which are discretised versions of joystick movements). Thank you. – Polyrogue Apr 10 '19 at 10:16
  • @Polyrogue May I ask for some references backing "using class weightings don't work with multi-label classification"? Regarding loss functions, you could try others such as the Focal Loss, which down-weights easily classified samples, helping to counteract the class imbalance. Another approach would be to use a framework to explain the predictions of your NN, such as Layerwise Relevance Propagation. – Hichame Yessou Dec 06 '19 at 23:58
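For reference, a NumPy sketch of the binary focal loss suggested in the comment above (the Lin et al. "Focal Loss for Dense Object Detection" formulation; gamma and alpha are the usual defaults, not values tuned for this problem):

```python
import numpy as np

def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss for multi-hot labels (typical defaults, untuned)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    # p_t is the predicted probability of the true class for each label.
    p_t = np.where(y_true == 1, y_pred, 1 - y_pred)
    alpha_t = np.where(y_true == 1, alpha, 1 - alpha)
    # (1 - p_t)^gamma down-weights labels the model already gets right.
    return np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t))
```

The modulating factor shrinks the contribution of confidently correct predictions, so the frequent, easily learned buttons stop dominating the gradient.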

0 Answers