I've been training my self-balancing agent to keep his waist above a certain height. But after a while, he just keeps repeating the same tactic over and over, and his mean reward and reward standard deviation stop changing. What's the best way to tackle this? Thank you.
using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
/// <summary>
/// ML-Agents agent that drives a jointed lower body (waist, butts, thighs, legs, feet)
/// with eight discrete action branches and is rewarded for keeping its waist above a
/// configurable world-space height.
/// </summary>
/// <remarks>
/// NOTE(review): joints are turned with <c>transform.Rotate</c>, which moves the bodies
/// directly rather than through the physics solver; if the parts are non-kinematic
/// Rigidbodies, driving joints/torques instead may train more reliably — confirm against
/// the scene setup.
/// NOTE(review): positions are observed and rewarded in world space, so the thresholds
/// below depend on where the body sits in the scene — verify if several training areas
/// are placed at different heights.
/// </remarks>
public class BalanceAgent : Agent
{
    // Assigned in InitializeAgent; not read anywhere else in this class.
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    // Rebuilt in Start in the fixed order waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL.
    public GameObject[] bodyParts = new GameObject[9];
    // Per-part pose captured in Start and restored in AgentReset.
    public Vector3[] posStart = new Vector3[9];
    public Vector3[] eulerStart = new Vector3[9];
    public RayPerceptionSensorComponent3D dirForward;

    // Waist world-space y above which the agent earns uprightReward each step.
    public float uprightHeight = -1.4f;
    // Waist world-space y at or below which the episode ends.
    public float fallHeight = -3f;
    // Per-step reward while the waist is above uprightHeight.
    public float uprightReward = 0.1f;
    // Per-step reward (negative = penalty) while the waist is at or below uprightHeight.
    public float belowReward = -0.02f;

    // Rigidbody of each bodyParts entry, same indices; built lazily by CachedRigidbodies.
    private Rigidbody[] bodyRigidbodies;

    /// <summary>Builds <see cref="bodyParts"/> and records every part's starting pose.</summary>
    public void Start() {
        bodyParts = new GameObject[] { waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL };
        // bodyParts was just replaced, so any previously cached Rigidbodies may be stale.
        bodyRigidbodies = null;
        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;
        }
    }

    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    /// <summary>Restores every part to its starting pose and zeroes its motion.</summary>
    public override void AgentReset() {
        var bodies = CachedRigidbodies();
        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];
            bodies[i].velocity = Vector3.zero;
            bodies[i].angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Applies one discrete action per joint, then rewards or penalizes the waist height.
    /// </summary>
    /// <param name="vectorAction">
    /// Eight discrete branch values; index k drives the k-th joint listed below.
    /// 1 rotates -1 degree, 2 rotates +1 degree, any other value holds still.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        float waistHeight = waist.transform.position.y;
        AddReward(waistHeight > uprightHeight ? uprightReward : belowReward);

        // The only terminal condition here: if the body can come to rest between fallHeight
        // and uprightHeight, the episode runs until the agent's max step (if one is set).
        if (waistHeight <= fallHeight) {
            Done();
            print("He fell too far...");
        }
    }

    /// <summary>
    /// Adds 12 floats per body part: world position, normalized Euler angles, linear
    /// velocity and angular velocity.
    /// </summary>
    public override void CollectObservations() {
        var bodies = CachedRigidbodies();
        for (int i = 0; i < bodyParts.Length; i++) {
            var part = bodyParts[i].transform;
            AddVectorObs(part.position);
            // eulerAngles is reported in [0, 360), so a tiny tilt either side of 0 jumps
            // between ~0 and ~360. Feed a continuous (-1, 1] value per axis instead.
            AddVectorObs(NormalizeAngles(part.eulerAngles));
            AddVectorObs(bodies[i].velocity);
            AddVectorObs(bodies[i].angularVelocity);
        }
    }

    /// <summary>Rotates <paramref name="joint"/> about its local Y axis by the action's direction, in degrees.</summary>
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }

    /// <summary>Maps a discrete branch value to a rotation step: 1 → -1, 2 → +1, otherwise 0.</summary>
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                // 0 (and 3, which the earlier code listed explicitly) means "hold still".
                return 0;
        }
    }

    /// <summary>Maps each Euler component from degrees to the range (-1, 1], with 0 degrees at 0.</summary>
    private static Vector3 NormalizeAngles(Vector3 euler) {
        return new Vector3(NormalizeAngle(euler.x), NormalizeAngle(euler.y), NormalizeAngle(euler.z));
    }

    /// <summary>Wraps <paramref name="degrees"/> into (-180, 180] and scales it by 1/180.</summary>
    private static float NormalizeAngle(float degrees) {
        float wrapped = degrees % 360f;
        if (wrapped > 180f) {
            wrapped -= 360f;
        }
        else if (wrapped <= -180f) {
            wrapped += 360f;
        }
        return wrapped / 180f;
    }

    /// <summary>
    /// Returns the Rigidbody of each <see cref="bodyParts"/> entry, building the cache on
    /// first use (or after <see cref="Start"/> replaces the array).
    /// </summary>
    private Rigidbody[] CachedRigidbodies() {
        if (bodyRigidbodies == null || bodyRigidbodies.Length != bodyParts.Length) {
            bodyRigidbodies = new Rigidbody[bodyParts.Length];
            for (int i = 0; i < bodyParts.Length; i++) {
                bodyRigidbodies[i] = bodyParts[i].GetComponent<Rigidbody>();
            }
        }
        return bodyRigidbodies;
    }
}