ml-agent

Unity's ML Agents doesn't call OnActionRecieved() (using Behavioural Cloning only)


I'm creating an Agent script, to park a car for me using exclusively Behavioural Cloning (hence no DecisionRequester(), Reward logic, etc. ). After recording a demonstration, the demo file says it has recorded only 1 step and 2 episodes despite me recording a lot more. Debug.Log shows a new episode does begin after I either collide with an obstacle of park, yet it still only records 2 episodes, and Debug.Log in the OnActionRecieved() method never shows up, therefore OnActionRecieved never gets called and I don't know why. This question has been asked on here before, but it either wasn't answered, or the answer wasn't applicable to my specific scenario. Heres the code for the Agent script:

using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Unity.MLAgents.Actuators;

public class CarParkingAgent : Agent
{
    public Transform targetParkingSpot;
    public CarController carController;
    public float maxRaycastDistance = 10f;
    private Rigidbody rb;
    public float maxSpeed = 10f;
    public float maxAngularVelocity = 10f;
    private bool isCarParked = false;
    public Vector3 tensor;
    private bool lastFrameParkedStatus = false;
    private int currentStep = 0;
    private bool useHeuristic = true; // after ur done with BC just delete this whole logic

void Start()
{
    rb = GetComponent<Rigidbody>();
    rb.inertiaTensor = tensor;
    rb.inertiaTensor = new Vector3(1829.532f, 1974.514f, 391.8728f);

    Debug.Log("Initial Car Position: " + transform.localPosition);
}

private void Update()
{
    // Check if the car is parked
    if (!isCarParked)
    {
        Debug.Log("car aint parked");
        isCarParked = IsCarOnParkingSpot();
    }
}
private void FixedUpdate()
{
    //idk what this does but if i remove it my entire IsCarOnParkingSpot() stops working
    if (!lastFrameParkedStatus)
    {
        isCarParked = IsCarOnParkingSpot();
    }
    lastFrameParkedStatus = isCarParked;
}

private void Heuristic()
{
    if (!useHeuristic)
        return;

    // Calculate target direction (normalized)
    Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    Vector3 normalizedToTarget = toTarget.normalized;

    // Apply heuristic actions based on target direction
    float steering = Vector3.Dot(normalizedToTarget, transform.right);
    float accelerate = Mathf.Clamp01(Vector3.Dot(rb.velocity, transform.forward));
    float brake = 0f;

    // Apply actions
    carController.HandleMotor(accelerate, brake);
    carController.HandleSteering(steering);
}

public override void CollectObservations(VectorSensor sensor)
{
    
    sensor.AddObservation(transform.localPosition);
    sensor.AddObservation(transform.rotation);
    sensor.AddObservation(carController.GetVelocity());

    Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    Vector3 normalizedToTarget = toTarget.normalized;
    sensor.AddObservation(toTarget.magnitude);
    sensor.AddObservation(normalizedToTarget);
    sensor.AddObservation(rb.velocity.magnitude / maxSpeed);
    sensor.AddObservation(rb.angularVelocity.magnitude / maxAngularVelocity);

    sensor.AddObservation(targetParkingSpot.localPosition);
    sensor.AddObservation(targetParkingSpot.rotation);

    Vector3[] raycastDirections = { transform.forward, transform.right, -transform.right, transform.forward + transform.right, transform.forward - transform.right };

    foreach (Vector3 direction in raycastDirections)
    {
        if (Physics.Raycast(transform.localPosition, direction, out RaycastHit hit, maxRaycastDistance))
        {
            float normalizedDistance = hit.distance / maxRaycastDistance;
            sensor.AddObservation(normalizedDistance);
        }
        else
        {
            sensor.AddObservation(-1f);
        }
    }
}

public override void OnActionReceived(ActionBuffers actions)
{
    Debug.Log("OnActionReceived called, current step = " + currentStep);
    currentStep++;
    float accelerate = actions.ContinuousActions[0];
    float brake = actions.ContinuousActions[1];
    float steering = actions.ContinuousActions[2];

    if (useHeuristic)
    {
        Heuristic();
    }
    else
    {
        // Use the agent's policy to handle actions
        carController.HandleMotor(accelerate, brake);
        carController.HandleSteering(steering);
    }
}

public override void OnEpisodeBegin()
{
    currentStep = 0;
    Debug.Log("new episode begun");
    ResetCarPosition();
    isCarParked = false;
    lastFrameParkedStatus = false;
}

public void OnTriggerEnter(Collider collision)
{
    if (collision.gameObject.CompareTag("Obstacle"))
    {
        ResetCarPosition();
        EndEpisode();
        isCarParked = false;
    }
    else if (collision.gameObject.CompareTag("ParkingSpot") && !isCarParked)
    {
        IsCarOnParkingSpot();
    }
}

public bool IsCarOnParkingSpot()
{
    Vector3 toParkingSpot = targetParkingSpot.position - transform.position;
    //Debug.Log("Distance from parking spot: " + toParkingSpot.magnitude);
    if (toParkingSpot.magnitude < 1.5f) 
    { 
        isCarParked = true;
        ResetCarPosition();
        EndEpisode();
        return true;
    }
    isCarParked = false;

    return false;
}

private void ResetCarPosition()
{
    rb.velocity = Vector3.zero;
    rb.angularVelocity = Vector3.zero;

    Vector3 localPosition = new Vector3(
        UnityEngine.Random.Range(-7f, -4f),
        0.1f,
        UnityEngine.Random.Range(-7f, 2f));

    //Debug.Log("Calculated Local Position: " + localPosition);

    transform.localPosition = localPosition;
    transform.localRotation = Quaternion.Euler(Vector3.zero);
    
}

}

If anyone has an idea for why OnActionRecieved() isn't called it'd be much appreciated. Thankyou in advance.


Solution

  • Ok I figured out the issue, adding an Awake() method solved the issue. I assume an Initialise() method should work too but it didn't in my case. Use the Awake() instead of the Start() method.

    void Awake()
    {
        Debug.Log("Awake called");
        rb = GetComponent<Rigidbody>();
        rb.inertiaTensor = tensor;
        rb.inertiaTensor = new Vector3(1829.532f, 1974.514f, 391.8728f);
        useHeuristic = true;
    }