unity-game-engine machine-learning ml-agent

How do I know the agents are working together?


I've been using ML-Agents for several months now and have been working on a self-balancing pair of legs. However, one question has been nagging me since the day I started: how do I KNOW for a fact that the agents are working together? All I've done is copy and paste the area prefab 9 times. Is that all you have to do to make the agents learn more efficiently, or is there something else I'm missing? Thanks.

training

Agent Script >>> (I've not really needed to use any other scripts besides this one. Area and academy have nothing in them.)

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;

/// <summary>
/// ML-Agents agent controlling a waist and ten hinge-driven leg parts.
/// Each step it reads 13 discrete action branches: branch 0 rotates the waist
/// around Y, branches 1-10 shift the one-degree-wide limit window of one hinge
/// joint each, and branches 11-12 push the waist along X and Z. Rewards are
/// handed out at the end of <see cref="AgentAction"/>.
/// </summary>
public class BalanceAgent : Agent {

    // Indices shared by bodyParts, hingeParts, jntLimParts, posStart and eulerStart.
    private const int WaistIndex = 0;
    private const int ButtRIndex = 1;
    private const int ButtLIndex = 2;
    private const int ThighRIndex = 3;
    private const int ThighLIndex = 4;
    private const int LegRIndex = 5;
    private const int LegLIndex = 6;
    private const int FootRIndex = 7;
    private const int FootLIndex = 8;
    private const int HipRIndex = 9;
    private const int HipLIndex = 10;
    private const int BodyPartCount = 11;

    // Maximum allowed yaw deviation (degrees) from the yaw recorded at reset.
    private const float MaxYawDeviation = 25f;

    // Per-joint travel bounds for the limit window, indexed like bodyParts.
    // Index 0 (waist) is not hinge-driven, so its entries are unused.
    private static readonly float[] JointLowerBound = { 0, -5, -60, -80, -80, -3, -3, -50, -50, -15, -45 };
    private static readonly float[] JointUpperBound = { 0, 60, 5, 80, 80, 80, 80, 50, 50, 45, 15 };

    // Value of JointLimits.max each joint is reset to at the start of an episode.
    private static readonly float[] JointStartMax = { 0, 0, 0, -15, -15, 15, 15, -15, -15, 0, 0 };

    private BalancingArea area;
    public GameObject floor;
    public GameObject finishBall;
    public GameObject waist;
    public GameObject wFront;           //Used to check balance of waist.
    public GameObject wBack;           //Used to check balance of waist.
    public GameObject hipR;
    public GameObject hipL;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    public BehaviorParameters behavePar;

    public GameObject sensorFront;
    public GameObject sensorBack;
    public GameObject sensorLeft;
    public GameObject sensorRight;

    // Base step (degrees per action) used for waist rotation and joint window shifts.
    public float bodyMoveSensitivity = 0.5f;

    public GameObject[] bodyParts = new GameObject[BodyPartCount];
    HingeJoint[] hingeParts = new HingeJoint[BodyPartCount];
    JointLimits[] jntLimParts = new JointLimits[BodyPartCount];

    // Pose of every body part captured in Start(), restored in AgentReset().
    Vector3[] posStart = new Vector3[BodyPartCount];
    Vector3[] eulerStart = new Vector3[BodyPartCount];

    // Waist euler angles recorded at the end of AgentReset(); reference for the yaw check.
    public Vector3 waistRot;

    // Last raw action value received on each branch (public so they show in the inspector).
    public float waistVec = 0;
    public float buttRVec = 0;
    public float buttLVec = 0;
    public float thighRVec = 0;
    public float thighLVec = 0;
    public float legRVec = 0;
    public float legLVec = 0;
    public float footRVec = 0;
    public float footLVec = 0;
    public float hipRVec = 0;
    public float hipLVec = 0;
    public float waistPushXVec = 0;
    public float waistPushZVec = 0;

    // Last step/push value derived from each branch; kept between calls so an
    // unrecognised action value leaves the previous value in place.
    float waistDir = 0;
    float buttRDir = 0;
    float buttLDir = 0;
    float thighRDir = 0;
    float thighLDir = 0;
    float legRDir = 0;
    float legLDir = 0;
    float footRDir = 0;
    float footLDir = 0;
    float hipRDir = 0;
    float hipLDir = 0;
    float waistPushDirX = 0;
    float waistPushDirZ = 0;

    /// <summary>
    /// Builds the bodyParts table, records every part's starting position and
    /// euler angles, and caches the HingeJoint of each part that has one.
    /// </summary>
    public void Start() {
        bodyParts = new GameObject[] { waist /*0*/, buttR /*1*/, buttL /*2*/, thighR /*3*/, thighL /*4*/, legR /*5*/, legL /*6*/, footR /*7*/, footL /*8*/, hipR /*9*/, hipL /*10*/};

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;

            HingeJoint hinge = bodyParts[i].GetComponent<HingeJoint>();
            if (hinge != null) {
                hingeParts[i] = hinge;
                // jntLimParts[i] is still a default JointLimits here, so this writes
                // default limits to the hinge; AgentReset() assigns the real start
                // values for joints 1-10.
                hingeParts[i].limits = jntLimParts[i];
            }
        }
    }

    /// <summary>Caches the parent BalancingArea.</summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    /// <summary>
    /// Starts a new episode: resets every hinge's limit window to its start
    /// angle, restores each body part to the pose captured in Start(), zeroes
    /// its velocities, and records the waist rotation for the yaw check.
    /// </summary>
    public override void AgentReset() {
        //floor.transform.eulerAngles = new Vector3(Random.Range(-10, 10), 0, Random.Range(-10, 10));             //Floor random rotation
        //finishBall.transform.localPosition = new Vector3(Random.Range(-7, 7), .65f, Random.Range(-7, 7));             //Ball random position

        // Index 0 (waist) is skipped: it has no entry driven by an action branch.
        for (int i = 1; i < JointStartMax.Length; i++) {
            jntLimParts[i].max = JointStartMax[i];
            jntLimParts[i].min = jntLimParts[i].max - 1;
        }

        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];

            Rigidbody body = bodyParts[i].GetComponent<Rigidbody>();
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;

            HingeJoint hinge = bodyParts[i].GetComponent<HingeJoint>();
            if (hinge != null) {
                hingeParts[i] = hinge;
                hingeParts[i].limits = jntLimParts[i];
            }
        }

        //waist.transform.eulerAngles = new Vector3(0, Random.Range(0, 360), 0);                //Random player direction
        waistRot = waist.transform.eulerAngles;
    }

    /// <summary>
    /// Applies one step of discrete actions and then assigns rewards.
    /// </summary>
    /// <param name="vectorAction">
    /// 13 branch values: [0] waist rotation step (0-6), [1]-[10] joint window
    /// step (0-6) in bodyParts order, [11] waist push on X (0-2), [12] waist
    /// push on Z (0-2).
    /// </param>
    public override void AgentAction(float[] vectorAction) {

        waistVec = (int)vectorAction[0];
        waistDir = StepFromAction((int)waistVec, waistDir);
        bodyParts[WaistIndex].transform.Rotate(0, waistDir, 0);

        buttRVec = (int)vectorAction[1];
        buttRDir = StepFromAction((int)buttRVec, buttRDir);
        DriveJoint(ButtRIndex, buttRDir);

        buttLVec = (int)vectorAction[2];
        buttLDir = StepFromAction((int)buttLVec, buttLDir);
        DriveJoint(ButtLIndex, buttLDir);

        thighRVec = (int)vectorAction[3];
        thighRDir = StepFromAction((int)thighRVec, thighRDir);
        DriveJoint(ThighRIndex, thighRDir);

        thighLVec = (int)vectorAction[4];
        thighLDir = StepFromAction((int)thighLVec, thighLDir);
        DriveJoint(ThighLIndex, thighLDir);

        legRVec = (int)vectorAction[5];
        legRDir = StepFromAction((int)legRVec, legRDir);
        DriveJoint(LegRIndex, legRDir);

        legLVec = (int)vectorAction[6];
        legLDir = StepFromAction((int)legLVec, legLDir);
        DriveJoint(LegLIndex, legLDir);

        footRVec = (int)vectorAction[7];
        footRDir = StepFromAction((int)footRVec, footRDir);
        DriveJoint(FootRIndex, footRDir);

        footLVec = (int)vectorAction[8];
        footLDir = StepFromAction((int)footLVec, footLDir);
        DriveJoint(FootLIndex, footLDir);

        hipRVec = (int)vectorAction[9];
        hipRDir = StepFromAction((int)hipRVec, hipRDir);
        DriveJoint(HipRIndex, hipRDir);

        hipLVec = (int)vectorAction[10];
        hipLDir = StepFromAction((int)hipLVec, hipLDir);
        DriveJoint(HipLIndex, hipLDir);

        waistPushXVec = (int)vectorAction[11];
        waistPushDirX = PushFromAction((int)waistPushXVec, waistPushDirX);
        waistPushZVec = (int)vectorAction[12];
        waistPushDirZ = PushFromAction((int)waistPushZVec, waistPushDirZ);

        Rigidbody waistBody = waist.GetComponent<Rigidbody>();
        waistBody.AddForce(waistPushDirX, 0, waistPushDirZ);              //Try to help move waist

        //waist.transform.eulerAngles = new Vector3(0, waistRot.y, 0);

        AlignSensors();
        ApplyRewards(waistBody);
    }

    /// <summary>
    /// Adds the vector observations: per body part its position, euler angles,
    /// velocity, angular velocity and current limit window, followed by the
    /// waist balance markers, the reference waist rotation and the ball position.
    /// </summary>
    public override void CollectObservations() {

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform partTransform = bodyParts[i].transform;
            Rigidbody body = bodyParts[i].GetComponent<Rigidbody>();

            AddVectorObs(partTransform.position);
            AddVectorObs(partTransform.eulerAngles);
            AddVectorObs(body.velocity);
            AddVectorObs(body.angularVelocity);
            AddVectorObs(jntLimParts[i].max);
            AddVectorObs(jntLimParts[i].min);
        }

        AddVectorObs(wFront.transform.position.y);
        AddVectorObs(wFront.transform.eulerAngles);
        AddVectorObs(wBack.transform.position.y);
        AddVectorObs(wBack.transform.eulerAngles);
        AddVectorObs(waistRot);             //Waist rotation recorded at the last reset.
        AddVectorObs(finishBall.transform.position);             //World position of the target ball.
    }

    /// <summary>
    /// Maps a 7-way discrete action to a signed step: 0 = hold, 1/2 = ±1x,
    /// 3/4 = ±2x, 5/6 = ±4x bodyMoveSensitivity.
    /// </summary>
    /// <param name="action">Discrete action value.</param>
    /// <param name="previous">Value returned unchanged when the action is outside 0-6.</param>
    private float StepFromAction(int action, float previous) {
        switch (action) {
            case 0: return 0;
            case 1: return bodyMoveSensitivity;
            case 2: return -bodyMoveSensitivity;
            case 3: return bodyMoveSensitivity * 2;
            case 4: return -bodyMoveSensitivity * 2;
            case 5: return bodyMoveSensitivity * 4;
            case 6: return -bodyMoveSensitivity * 4;
            default: return previous;
        }
    }

    /// <summary>
    /// Maps a 3-way discrete action to a push component: 0 = none, 1 = -1, 2 = +1.
    /// </summary>
    /// <param name="action">Discrete action value.</param>
    /// <param name="previous">Value returned unchanged when the action is outside 0-2.</param>
    private float PushFromAction(int action, float previous) {
        switch (action) {
            case 0: return 0;
            case 1: return -1;
            case 2: return 1;
            default: return previous;
        }
    }

    /// <summary>
    /// Shifts the one-degree limit window of a hinge joint by <paramref name="delta"/>
    /// while the window lies strictly inside the joint's travel bounds; once it has
    /// reached or passed a bound, snaps it back just inside that bound instead.
    /// </summary>
    /// <param name="index">Joint index (1-10) into hingeParts / jntLimParts.</param>
    /// <param name="delta">Signed step in degrees added to the window's max.</param>
    private void DriveJoint(int index, float delta) {
        float lower = JointLowerBound[index];
        float upper = JointUpperBound[index];
        JointLimits window = jntLimParts[index];

        if (window.max < upper && window.min > lower) {
            window.max += delta;
        }
        else if (window.min <= lower) {
            // Recover to lower + 2 so that min (= max - 1) ends up strictly above the
            // bound; recovering to lower + 1 would leave min == lower and the joint
            // would fail the "inside" test on every later step.
            window.max = lower + 2;
        }
        else if (window.max >= upper) {
            window.max = upper - 1;
        }
        window.min = window.max - 1;

        // Store and apply in every case so the observed window (CollectObservations)
        // always matches what the physical hinge is using.
        jntLimParts[index] = window;
        hingeParts[index].limits = window;
    }

    /// <summary>
    /// Keeps the four sensors level (zero X and Z rotation) and turned to fixed
    /// offsets from the waist's current yaw.
    /// </summary>
    private void AlignSensors() {
        float waistYaw = waist.transform.eulerAngles.y;
        sensorFront.transform.eulerAngles = new Vector3(0, waistYaw - 90, 0);
        sensorBack.transform.eulerAngles = new Vector3(0, waistYaw + 90, 0);
        sensorLeft.transform.eulerAngles = new Vector3(0, waistYaw - 180, 0);
        sensorRight.transform.eulerAngles = new Vector3(0, waistYaw, 0);
    }

    /// <summary>
    /// Assigns this step's rewards and ends the episode when the waist moves
    /// too fast, leaves the allowed height band, turns too far from its
    /// reference yaw, or the balance markers differ too much in height.
    /// </summary>
    /// <param name="waistBody">Rigidbody of the waist, used for the speed check.</param>
    private void ApplyRewards(Rigidbody waistBody) {
        Vector3 waistPos = waist.transform.position;
        Vector3 ballPos = finishBall.transform.position;

        AddReward(.1f);             //Survival reward.

        //Maintain waist position to ball.
        // NOTE(review): the penalty needs BOTH axes to be off by more than .25 and
        // only scales with the X offset - confirm this is the intended shaping.
        float ballOffsetX = Mathf.Abs(ballPos.x - waistPos.x);
        float ballOffsetZ = Mathf.Abs(ballPos.z - waistPos.z);
        if (ballOffsetX > .25f && ballOffsetZ > .25f) {
            AddReward(-.1f * ballOffsetX);
        }

        //Maintain waist slow velocity.
        if (waistBody.velocity.magnitude >= 20f) {
            AddReward(-.1f);
            Done();
        }

        //Maintain waist height.
        if (waistPos.y < -2 || waistPos.y > 6) {
            AddReward(-.1f * Mathf.Abs(ballPos.y - waistPos.y));
            Done();
        }

        //Maintain waist rotation on Y.
        // eulerAngles.y wraps into 0..360, so comparing it directly with
        // waistRot.y +/- 25 misfires across the 0/360 seam; DeltaAngle gives the
        // shortest signed difference instead.
        float yawError = Mathf.Abs(Mathf.DeltaAngle(waistRot.y, waist.transform.eulerAngles.y));
        if (yawError > MaxYawDeviation) {
            AddReward(-.1f * yawError);
            Done();
        }

        //Maintain waist rotation forward and backwards.
        float pitchHeightGap = Mathf.Abs(wFront.transform.position.y - wBack.transform.position.y);
        if (pitchHeightGap > 25) {
            AddReward(-.1f * pitchHeightGap);
            Done();
        }

        //Maintain waist rotation left and right.
        float rollHeightGap = Mathf.Abs(buttR.transform.position.y - buttL.transform.position.y);
        if (rollHeightGap > 25) {
            AddReward(-.1f * rollHeightGap);
            Done();
        }

        /*
        if (waist.transform.position.x > posStart[0].x + 10 || waist.transform.position.x < posStart[0].x - 10 || waist.transform.position.z > posStart[0].z + 10 || waist.transform.position.z < posStart[0].z - 10) {              //Maintain waist position.
            AddReward(-.01f);
            Done();
        }
        */
    }
}

Solution

  • Yes, I believe all you need to do is have multiple instances of the prefab. As long as there are multiple Areas in the scene, the agents in them should all contribute experience to the same training batches.

    If you want to measure how having multiple areas changes things, train with a single area for some time and look at a graph of cumulative reward vs. episode number to see how high it gets. Then do the same thing with many areas and compare the two graphs.