// 🍐☘️🍏🥬🥦🌳🌲
// 카테고리 (Category)
// 작성일 (Date written): 2023. 8. 1. 10:50
// 작성자 (Author): ˗ˋ ୨୧ ˊ˗
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;   
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent that learns to move toward a target placed at a random
/// position each episode, by applying continuous X/Z forces to its Rigidbody.
/// </summary>
public class DefendInDetail : Agent
{
    // Distance below which the target counts as reached (ends the episode with reward 1).
    private const float TargetReachedDistance = 1.42f;

    // The target is respawned uniformly in [-TargetSpawnHalfExtent, TargetSpawnHalfExtent) on X and Z.
    private const float TargetSpawnHalfExtent = 4f;

    // Local Y used when (re)placing the agent and the target.
    private const float SpawnHeight = 0.5f;

    Rigidbody rBody;

    /// <summary>Target the agent tries to reach; repositioned at the start of every episode.</summary>
    public Collider Target;

    /// <summary>Scale applied to the action vector before it is applied as a force.</summary>
    public float forceMultiplier = 10;

    /// <summary>
    /// One-time setup: caches the Rigidbody.
    /// Done in Agent.Initialize rather than Start, because ML-Agents calls Initialize when the
    /// agent is first enabled, so rBody is already set if an episode begins before Start runs.
    /// </summary>
    public override void Initialize()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Called at the start of each episode: resets the agent if it fell below the floor
    /// and moves the target to a new random position.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // Only an agent that fell off (local y below 0) is reset; otherwise it keeps
        // its current position and momentum into the next episode.
        if (transform.localPosition.y < 0)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, SpawnHeight, 0);
        }

        // Randomize the target's position on X/Z for the new episode.
        Target.transform.localPosition = new Vector3(
            Random.value * (2f * TargetSpawnHalfExtent) - TargetSpawnHalfExtent,
            SpawnHeight,
            Random.value * (2f * TargetSpawnHalfExtent) - TargetSpawnHalfExtent);
    }

    /// <summary>
    /// Supplies observations to the learning process: target position (3), agent position (3),
    /// and agent velocity on X and Z (2), for 8 floats in total.
    /// NOTE(review): this count should match the Behavior Parameters' vector observation size.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(Target.transform.localPosition); // target position
        sensor.AddObservation(transform.localPosition);        // agent position

        sensor.AddObservation(rBody.velocity.x); // agent velocity, X
        sensor.AddObservation(rBody.velocity.z); // agent velocity, Z
    }

    /// <summary>
    /// Applies the chosen action and assigns rewards.
    /// Action layout (2 continuous values): [0] = force along X, [1] = force along Z.
    /// Reaching the target gives reward 1 and ends the episode.
    /// Falling below the floor ends the episode without reward.
    /// </summary>
    /// <param name="actionBuffers">Actions chosen by the policy or by <see cref="Heuristic"/>.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var continuousActions = actionBuffers.ContinuousActions;

        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = continuousActions[0];
        controlSignal.z = continuousActions[1];
        rBody.AddForce(controlSignal * forceMultiplier);

        float distanceToTarget = Vector3.Distance(transform.localPosition, Target.transform.localPosition);

        if (distanceToTarget < TargetReachedDistance)
        {
            // Reached the target.
            SetReward(1.0f);
            EndEpisode();
        }
        else if (transform.localPosition.y < 0)
        {
            // Fell off the platform.
            EndEpisode();
        }
    }

    /// <summary>
    /// Manual-control (or scripted-control) fallback: maps the "Horizontal"/"Vertical"
    /// input axes to the two continuous actions.
    /// </summary>
    /// <param name="actionsOut">Buffer to fill with the heuristic actions.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }
}