I've been using ML-Agents for several months now and have been working on a self-balancing pair of legs. However, one question has been nagging me since the day I started: how do I KNOW for a fact that the agents are working together? All I've done is copy and paste the area prefab 9 times. Is that all you have to do to make the agents learn more efficiently, or is there something else I'm missing? Thanks.
Agent Script >>> (I've not really needed to use any other scripts besides this one. Area and academy have nothing in them.)
using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;
/// <summary>
/// ML-Agents agent that drives a pair of legs (waist, hips, butts, thighs, legs, feet).
/// Each hinged body part is position-controlled by sliding a one-degree-wide
/// <see cref="JointLimits"/> window (min = max - 1) along the joint's allowed range.
/// </summary>
public class BalanceAgent : Agent {
    // Index 0 of every per-part array is the waist; indices 1-10 are the hinged parts
    // in the order built in Start(): buttR, buttL, thighR, thighL, legR, legL, footR, footL, hipR, hipL.
    private const int BodyPartCount = 11;

    // Width of the limit window written to each hinge (min is always max - LimitWindow).
    private const float LimitWindow = 1f;

    // After a joint hits a bound, max is pulled back inside the range by these margins so
    // that the "inside range" test passes again on the next step. The lower margin must be
    // strictly greater than LimitWindow, otherwise min lands exactly on the bound and the
    // joint can never leave it.
    private const float LowerRecoverMargin = 2f;
    private const float UpperRecoverMargin = 1f;

    // Maximum signed yaw deviation (degrees) from the reset heading before the episode ends.
    private const float MaxYawError = 25f;

    // Maximum height difference (world units) between paired reference points before the
    // episode ends.
    // NOTE(review): 25 world units looks like it was copied from the 25 degree yaw threshold;
    // confirm the intended value against the model's scale.
    private const float MaxHeightSkew = 25f;

    // Per-joint range for the window's max/min, and the max value assigned on reset.
    // Entry 0 (waist) is unused padding so the tables share the body-part indexing.
    private static readonly float[] JointUpper = { 0f, 60f, 5f, 80f, 80f, 80f, 80f, 50f, 50f, 45f, 15f };
    private static readonly float[] JointLower = { 0f, -5f, -60f, -80f, -80f, -3f, -3f, -50f, -50f, -15f, -45f };
    private static readonly float[] JointResetMax = { 0f, 0f, 0f, -15f, -15f, 15f, 15f, -15f, -15f, 0f, 0f };

    private BalancingArea area;

    public GameObject floor;
    public GameObject finishBall;
    public GameObject waist;
    public GameObject wFront; //Used to check balance of waist.
    public GameObject wBack; //Used to check balance of waist.
    public GameObject hipR;
    public GameObject hipL;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    public BehaviorParameters behavePar;
    public GameObject sensorFront;
    public GameObject sensorBack;
    public GameObject sensorLeft;
    public GameObject sensorRight;

    // Base rotation step (degrees) for one action; branches 3-6 scale it by 2x and 4x.
    public float bodyMoveSensitivity = 0.5f;

    public GameObject[] bodyParts = new GameObject[BodyPartCount];

    private readonly HingeJoint[] hingeParts = new HingeJoint[BodyPartCount];
    private readonly Rigidbody[] rigidParts = new Rigidbody[BodyPartCount];
    private readonly JointLimits[] jntLimParts = new JointLimits[BodyPartCount];
    private readonly Vector3[] posStart = new Vector3[BodyPartCount];
    private readonly Vector3[] eulerStart = new Vector3[BodyPartCount];

    // Waist euler angles captured at the end of AgentReset(); reference heading for the yaw check.
    public Vector3 waistRot;

    // Last action branch received for each output, kept public so it shows in the inspector.
    public float waistVec = 0;
    public float buttRVec = 0;
    public float buttLVec = 0;
    public float thighRVec = 0;
    public float thighLVec = 0;
    public float legRVec = 0;
    public float legLVec = 0;
    public float footRVec = 0;
    public float footLVec = 0;
    public float hipRVec = 0;
    public float hipLVec = 0;
    public float waistPushXVec = 0;
    public float waistPushZVec = 0;

    /// <summary>
    /// Builds the body-part table, records each part's starting pose, and caches its
    /// Rigidbody and HingeJoint (if any).
    /// </summary>
    public void Start() {
        bodyParts = new GameObject[] {
            waist /*0*/, buttR /*1*/, buttL /*2*/, thighR /*3*/, thighL /*4*/, legR /*5*/,
            legL /*6*/, footR /*7*/, footL /*8*/, hipR /*9*/, hipL /*10*/
        };

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform partTransform = bodyParts[i].transform;
            posStart[i] = partTransform.position;
            eulerStart[i] = partTransform.eulerAngles;

            rigidParts[i] = bodyParts[i].GetComponent<Rigidbody>();
            hingeParts[i] = bodyParts[i].GetComponent<HingeJoint>();
            if (hingeParts[i] != null) {
                // Seed the shadow copy from the joint instead of overwriting the joint
                // with an all-zero JointLimits struct.
                jntLimParts[i] = hingeParts[i].limits;
            }
        }
    }

    /// <summary>Looks up the enclosing <c>BalancingArea</c> among this object's parents.</summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    /// <summary>
    /// Restores every body part to the pose recorded in Start(), zeroes its velocities,
    /// puts each hinge's limit window back at its reset angle, and records the waist
    /// heading used by the yaw check.
    /// </summary>
    public override void AgentReset() {
        //floor.transform.eulerAngles = new Vector3(Random.Range(-10, 10), 0, Random.Range(-10, 10)); //Floor random rotation
        //finishBall.transform.localPosition = new Vector3(Random.Range(-7, 7), .65f, Random.Range(-7, 7)); //Ball random position
        for (int i = 0; i < bodyParts.Length; i++) {
            if (i > 0) {
                jntLimParts[i].max = JointResetMax[i];
                jntLimParts[i].min = jntLimParts[i].max - LimitWindow;
            }

            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];
            rigidParts[i].velocity = Vector3.zero;
            rigidParts[i].angularVelocity = Vector3.zero;

            if (hingeParts[i] != null) {
                hingeParts[i].limits = jntLimParts[i];
            }
        }
        //waist.transform.eulerAngles = new Vector3(0, Random.Range(0, 360), 0); //Random player direction
        waistRot = waist.transform.eulerAngles;
    }

    /// <summary>
    /// Applies one step of actions and then the reward / termination rules.
    /// </summary>
    /// <param name="vectorAction">
    /// Thirteen values, each truncated to an int and used as a branch index:
    /// [0] waist yaw step, [1]-[10] hinge steps for body parts 1-10 (branches 0-6),
    /// [11] and [12] waist push along X and Z (branches 0-2).
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        waistVec = (int)vectorAction[0];
        bodyParts[0].transform.Rotate(0, BranchToDelta((int)waistVec), 0);

        buttRVec = ApplyJointAction(1, vectorAction[1]);
        buttLVec = ApplyJointAction(2, vectorAction[2]);
        thighRVec = ApplyJointAction(3, vectorAction[3]);
        thighLVec = ApplyJointAction(4, vectorAction[4]);
        legRVec = ApplyJointAction(5, vectorAction[5]);
        legLVec = ApplyJointAction(6, vectorAction[6]);
        footRVec = ApplyJointAction(7, vectorAction[7]);
        footLVec = ApplyJointAction(8, vectorAction[8]);
        hipRVec = ApplyJointAction(9, vectorAction[9]);
        hipLVec = ApplyJointAction(10, vectorAction[10]);

        waistPushXVec = (int)vectorAction[11];
        waistPushZVec = (int)vectorAction[12];
        rigidParts[0].AddForce(BranchToPush((int)waistPushXVec), 0, BranchToPush((int)waistPushZVec)); //Try to help move waist
        //waist.transform.eulerAngles = new Vector3(0, waistRot.y, 0);

        AlignSensors();
        ApplyRewards();
    }

    /// <summary>
    /// Adds the observation vector: for each of the 11 body parts its position, euler
    /// angles, velocity, angular velocity and limit window (max, min), followed by the
    /// wFront / wBack height and euler angles, the reset waist rotation and the ball position.
    /// </summary>
    public override void CollectObservations() {
        // NOTE(review): positions below are world-space. If the area prefab is duplicated at
        // different world offsets, identical poses produce different observations per copy;
        // consider observing positions relative to the area or the waist - confirm before changing,
        // since it invalidates already-trained models.
        for (int i = 0; i < bodyParts.Length; i++) {
            Transform partTransform = bodyParts[i].transform;
            AddVectorObs(partTransform.position);
            AddVectorObs(partTransform.eulerAngles);
            AddVectorObs(rigidParts[i].velocity);
            AddVectorObs(rigidParts[i].angularVelocity);
            AddVectorObs(jntLimParts[i].max);
            AddVectorObs(jntLimParts[i].min);
        }
        AddVectorObs(wFront.transform.position.y);
        AddVectorObs(wFront.transform.eulerAngles);
        AddVectorObs(wBack.transform.position.y);
        AddVectorObs(wBack.transform.eulerAngles);
        AddVectorObs(waistRot); //Waist rotation recorded at the last reset.
        AddVectorObs(finishBall.transform.position); //Target ball position.
    }

    // Maps an action branch (0-6) to a signed rotation step; any other value yields 0.
    private float BranchToDelta(int branch) {
        switch (branch) {
            case 1: return bodyMoveSensitivity;
            case 2: return -bodyMoveSensitivity;
            case 3: return bodyMoveSensitivity * 2;
            case 4: return -bodyMoveSensitivity * 2;
            case 5: return bodyMoveSensitivity * 4;
            case 6: return -bodyMoveSensitivity * 4;
            default: return 0f;
        }
    }

    // Maps a push branch to a force component: 1 -> -1, 2 -> +1, anything else -> 0.
    private static float BranchToPush(int branch) {
        switch (branch) {
            case 1: return -1f;
            case 2: return 1f;
            default: return 0f;
        }
    }

    // Truncates the raw action to a branch, steps the given joint, and returns the branch
    // so the caller can store it in the matching public *Vec field.
    private float ApplyJointAction(int index, float rawAction) {
        int branch = (int)rawAction;
        StepJoint(index, BranchToDelta(branch));
        return branch;
    }

    // Slides the joint's limit window by delta while it is strictly inside its range;
    // once it touches a bound, snaps the window back inside instead of moving it.
    private void StepJoint(int index, float delta) {
        JointLimits window = jntLimParts[index];
        float upper = JointUpper[index];
        float lower = JointLower[index];

        if (window.max < upper && window.min > lower) {
            window.max += delta;
        }
        else if (window.min <= lower) {
            window.max = lower + LowerRecoverMargin;
        }
        else { // window.max >= upper
            window.max = upper - UpperRecoverMargin;
        }
        window.min = window.max - LimitWindow;

        jntLimParts[index] = window;
        // Always push the window to the hinge so the observed limits match the physical ones,
        // including on the step where the window is snapped back from a bound.
        hingeParts[index].limits = window;
    }

    // Zeroes each sensor's pitch and roll and locks its yaw to the waist's yaw plus a
    // fixed per-sensor offset.
    private void AlignSensors() {
        float waistYaw = waist.transform.eulerAngles.y;
        sensorFront.transform.eulerAngles = new Vector3(0, waistYaw - 90, 0);
        sensorBack.transform.eulerAngles = new Vector3(0, waistYaw + 90, 0);
        sensorLeft.transform.eulerAngles = new Vector3(0, waistYaw - 180, 0);
        sensorRight.transform.eulerAngles = new Vector3(0, waistYaw, 0);
    }

    // Per-step reward shaping and episode termination checks.
    private void ApplyRewards() {
        Vector3 ballPos = finishBall.transform.position;
        Vector3 waistPos = waist.transform.position;

        AddReward(.1f); //Survival reward.

        // Penalise drifting away from the ball.
        // NOTE(review): this only fires when BOTH the x and z offsets exceed .25, and it
        // penalises the x offset only - confirm whether "||" and a planar distance were intended.
        float offsetX = Mathf.Abs(ballPos.x - waistPos.x);
        float offsetZ = Mathf.Abs(ballPos.z - waistPos.z);
        if (offsetX > .25f && offsetZ > .25f) {
            AddReward(-.1f * offsetX);
        }

        if (rigidParts[0].velocity.magnitude >= 20f) { //Maintain waist slow velocity.
            AddReward(-.1f);
            Done();
        }

        if (waistPos.y < -2 || waistPos.y > 6) { //Maintain waist height.
            AddReward(-.1f * Mathf.Abs(ballPos.y - waistPos.y));
            Done();
        }

        // Maintain waist rotation on Y. eulerAngles.y lives in [0, 360), so the signed
        // shortest difference is used; comparing raw values would read a slight negative
        // yaw from a 0 degree heading as ~359 degrees and end the episode.
        float yawError = Mathf.Abs(Mathf.DeltaAngle(waistRot.y, waist.transform.eulerAngles.y));
        if (yawError > MaxYawError) {
            AddReward(-.1f * yawError);
            Done();
        }

        //Maintain waist rotation forward and backwards.
        float pitchSkew = Mathf.Abs(wFront.transform.position.y - wBack.transform.position.y);
        if (pitchSkew > MaxHeightSkew) {
            AddReward(-.1f * pitchSkew);
            Done();
        }

        //Maintain waist rotation left and right.
        float rollSkew = Mathf.Abs(buttR.transform.position.y - buttL.transform.position.y);
        if (rollSkew > MaxHeightSkew) {
            AddReward(-.1f * rollSkew);
            Done();
        }

        /*
        if (waist.transform.position.x > posStart[0].x + 10 || waist.transform.position.x < posStart[0].x - 10 || waist.transform.position.z > posStart[0].z + 10 || waist.transform.position.z < posStart[0].z - 10) { //Maintain waist position.
        AddReward(-.01f);
        Done();
        }
        */
    }
}
I believe so: all you need to do is have multiple instances of the prefab. As long as there are multiple Areas in the scene, they should be able to coordinate their batches for learning.
If you want to measure how having multiple areas changes things, I would have one area and let it play for some time, and look at a graph of cumulative reward vs. episode number and see how high it gets, then do the same thing with many areas and see how the same graph looks with that.