本篇想分享一下我对例子源码的分析,其实也就是一些我理解之后加上的注释和思路,希望对你有帮助。
这个例子主要是训练方块左右移动以获得最大奖励:左边的目标奖励小,右边的目标奖励大,所以训练到最后方块会往右边走,而且是一格一格地走。 主要源码分析:BasicAgent.cs
using UnityEngine;
using MLAgents;

/// <summary>
/// Agent of the "Basic" example: a block that moves one cell at a time along a
/// line and is rewarded for reaching either a small goal or a large goal.
/// </summary>
public class BasicAgent : Agent
{
    /// <summary>
    /// Academy found in the scene; queried to tell training apart from inference.
    /// </summary>
    [Header("Specific to Basic")]
    private BasicAcademy academy;

    /// <summary>
    /// Minimum time between two decision requests while running inference.
    /// </summary>
    public float timeBetweenDecisionsAtInference;

    /// <summary>
    /// Time accumulated since the last decision request (inference only).
    /// </summary>
    private float timeSinceDecision;

    /// <summary>
    /// Current cell index of the agent.
    /// </summary>
    int position;

    /// <summary>
    /// Cell index of the small goal.
    /// </summary>
    int smallGoalPosition;

    /// <summary>
    /// Cell index of the large goal.
    /// </summary>
    int largeGoalPosition;

    /// <summary>
    /// Scene objects that visualise the two goals.
    /// </summary>
    public GameObject largeGoal;
    public GameObject smallGoal;

    /// <summary>
    /// Lowest and highest cell index the agent may occupy.
    /// </summary>
    int minPosition;
    int maxPosition;

    /// <summary>
    /// Looks up the <c>BasicAcademy</c> instance in the scene.
    /// </summary>
    public override void InitializeAgent()
    {
        academy = FindObjectOfType<BasicAcademy>();
    }

    /// <summary>
    /// Reports the current cell as a one-hot observation with 20 slots: the slot
    /// whose index equals <c>position</c> is 1 and every other slot is 0.
    /// For example, with 5 slots and position 3 the vector would be [0,0,0,1,0].
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(position, 20);
    }

    /// <summary>
    /// Applies one discrete action: 1 moves a cell to the left, 2 moves a cell to
    /// the right, anything else stays in place. Then hands out the rewards.
    /// </summary>
    /// <param name="vectorAction">Action vector; only element 0 is read.</param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Discrete branch value as configured on the brain (0, 1 or 2 here).
        int choice = (int)vectorAction[0];
        int step = choice == 1 ? -1 : (choice == 2 ? 1 : 0);

        // Move, then keep the cell index inside [minPosition, maxPosition].
        position = Mathf.Min(Mathf.Max(position + step, minPosition), maxPosition);
        gameObject.transform.position = CellToWorld(position);

        // Small penalty on every step so the agent is pushed to reach a goal.
        AddReward(-0.01f);

        if (position == smallGoalPosition)
        {
            Done();
            AddReward(0.1f);
        }

        // The large goal pays ten times more than the small one.
        if (position == largeGoalPosition)
        {
            Done();
            AddReward(1f);
        }
    }

    /// <summary>
    /// Restores the start cell, the bounds and the goal cells, and places both
    /// goal objects in the scene. Runs at start-up and after each episode.
    /// </summary>
    public override void AgentReset()
    {
        position = 10;
        minPosition = 0;
        maxPosition = 20;
        smallGoalPosition = 7;
        largeGoalPosition = 17;

        smallGoal.transform.position = CellToWorld(smallGoalPosition);
        largeGoal.transform.position = CellToWorld(largeGoalPosition);
    }

    /// <summary>
    /// Nothing to clean up when the agent is done.
    /// </summary>
    public override void AgentOnDone()
    {
    }

    /// <summary>
    /// Drives the decision timer on every physics tick.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Requests a decision on every call while training; during inference it
    /// requests one only after <c>timeBetweenDecisionsAtInference</c> has elapsed.
    /// </summary>
    private void WaitTimeInference()
    {
        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        bool intervalElapsed = timeSinceDecision >= timeBetweenDecisionsAtInference;
        if (intervalElapsed)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }

    /// <summary>
    /// Converts a cell index into the world position used by this scene
    /// (cell 10 sits at x = 0).
    /// </summary>
    private static Vector3 CellToWorld(int cell)
    {
        return new Vector3(cell - 10f, 0f, 0f);
    }
}
主要是训练平台让小球不掉下去,需要同时关注小球的速度,位置,平台的角度。 主要源码分析:Ball3DAgent.cs
using UnityEngine;
using MLAgents;

/// <summary>
/// Agent of the "3DBall" example: a platform that tilts around its X and Z axes
/// to keep a ball from falling off.
/// </summary>
public class Ball3DAgent : Agent
{
    /// <summary>
    /// The ball this platform has to balance.
    /// </summary>
    [Header("Specific to Ball3D")]
    public GameObject ball;

    /// <summary>
    /// Rigidbody of <see cref="ball"/>, cached in <see cref="InitializeAgent"/>.
    /// </summary>
    private Rigidbody ballRb;

    /// <summary>
    /// Caches the ball's Rigidbody so its velocity can be read and reset.
    /// </summary>
    public override void InitializeAgent()
    {
        ballRb = ball.GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Collects the observations: the platform's tilt, the ball's position
    /// relative to the platform, and the ball's velocity.
    /// </summary>
    public override void CollectObservations()
    {
        // Platform tilt, read from the z and x components of the rotation quaternion.
        AddVectorObs(gameObject.transform.rotation.z);
        AddVectorObs(gameObject.transform.rotation.x);

        // Ball position relative to the platform.
        AddVectorObs(ball.transform.position - gameObject.transform.position);

        // Ball velocity.
        AddVectorObs(ballRb.velocity);
    }

    /// <summary>
    /// Applies the chosen tilt and assigns the reward for this step.
    /// </summary>
    /// <param name="vectorAction">
    /// Continuous actions: element 0 tilts around Z, element 1 tilts around X.
    /// Each is clamped to [-1, 1] and scaled by 2 before being applied.
    /// </param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
        {
            var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
            var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

            // Only rotate while the quaternion component is inside (-0.25, 0.25)
            // in the direction of travel, so the platform cannot tip over.
            if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
                (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
            {
                gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
            }

            if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
                (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
            {
                gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
            }
        }

        // The episode fails when the ball has dropped more than 2 units below the
        // platform or drifted more than 3 units away on X or Z.
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            Done();
            SetReward(-1f);
        }
        else
        {
            SetReward(0.1f);
        }
    }

    /// <summary>
    /// Resets the platform to a small random tilt and drops the ball from a
    /// random spot above it with zero velocity.
    /// </summary>
    public override void AgentReset()
    {
        // Start from the identity rotation. The all-zero quaternion used before is
        // not a unit quaternion and therefore not a valid rotation.
        gameObject.transform.rotation = Quaternion.identity;
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));

        ballRb.velocity = Vector3.zero;
        ball.transform.position =
            new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
            + gameObject.transform.position;
    }
}
在四周有墙的地方,让方块去找绿色为目标,而避开红色的陷阱,也是一格格走的,而且是视觉学习。 主要源码分析:GridAcademy.cs
using System.Collections.Generic;
using UnityEngine;
using System.Linq;
using MLAgents;

/// <summary>
/// Academy of the "GridWorld" example: sizes the walled grid, the cameras, and
/// on every reset scatters the goals, pits and the agent over distinct cells.
/// </summary>
public class GridAcademy : Academy
{
    /// <summary>
    /// Goal and pit instances currently spawned in the scene.
    /// </summary>
    [HideInInspector]
    public List<GameObject> actorObjs;

    /// <summary>
    /// One entry per actor to spawn; each value is an index into
    /// <c>objects</c> (2 = pit, added once per "numObstacles"; 1 = goal, added
    /// once per "numGoals").
    /// </summary>
    [HideInInspector]
    public int[] players;

    /// <summary>
    /// The agent object that is repositioned on every reset.
    /// </summary>
    public GameObject trueAgent;

    /// <summary>
    /// Side length of the grid; the grid has gridSize x gridSize cells.
    /// </summary>
    public int gridSize;

    /// <summary>
    /// Object that carries the scene camera.
    /// </summary>
    public GameObject camObject;

    /// <summary>
    /// Scene camera taken from <see cref="camObject"/>.
    /// </summary>
    Camera cam;

    /// <summary>
    /// Camera named "agentCam" in the scene.
    /// </summary>
    Camera agentCam;

    /// <summary>
    /// Agent prefab.
    /// </summary>
    public GameObject agentPref;

    /// <summary>
    /// Goal prefab.
    /// </summary>
    public GameObject goalPref;

    /// <summary>
    /// Pit prefab.
    /// </summary>
    public GameObject pitPref;

    /// <summary>
    /// Prefab lookup table: { agentPref, goalPref, pitPref }.
    /// </summary>
    GameObject[] objects;

    /// <summary>
    /// Floor plane and the four scene objects named "sN", "sS", "sE", "sW"
    /// that are scaled and placed around the grid.
    /// </summary>
    GameObject plane;
    GameObject sN;
    GameObject sS;
    GameObject sE;
    GameObject sW;

    /// <summary>
    /// Reads the grid size from the reset parameters and caches the scene
    /// objects and cameras this academy manipulates.
    /// </summary>
    public override void InitializeAcademy()
    {
        gridSize = (int)resetParameters["gridSize"];
        cam = camObject.GetComponent<Camera>();

        objects = new GameObject[3] {agentPref, goalPref, pitPref};

        agentCam = GameObject.Find("agentCam").GetComponent<Camera>();

        actorObjs = new List<GameObject>();

        plane = GameObject.Find("Plane");
        sN = GameObject.Find("sN");
        sS = GameObject.Find("sS");
        sW = GameObject.Find("sW");
        sE = GameObject.Find("sE");
    }

    /// <summary>
    /// Rebuilds the <see cref="players"/> list from the reset parameters and
    /// resizes/repositions the cameras, the floor and the four walls.
    /// </summary>
    public void SetEnvironment()
    {
        // The scene camera follows the reset parameter directly.
        int paramGridSize = (int)resetParameters["gridSize"];
        cam.transform.position = new Vector3(
            -(paramGridSize - 1) / 2f,
            paramGridSize * 1.25f,
            -(paramGridSize - 1) / 2f);
        cam.orthographicSize = (paramGridSize + 5f) / 2f;

        List<int> playersList = new List<int>();

        for (int i = 0; i < (int)resetParameters["numObstacles"]; i++)
        {
            playersList.Add(2);
        }

        for (int i = 0; i < (int)resetParameters["numGoals"]; i++)
        {
            playersList.Add(1);
        }
        players = playersList.ToArray();

        // Floor and walls are sized from the gridSize field.
        plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
        plane.transform.position = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
        sN.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sS.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sN.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, gridSize);
        sS.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, -1);
        sE.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sW.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sE.transform.position = new Vector3(gridSize, 0.0f, (gridSize - 1) / 2f);
        sW.transform.position = new Vector3(-1, 0.0f, (gridSize - 1) / 2f);

        agentCam.orthographicSize = (gridSize) / 2f;
        agentCam.transform.position = new Vector3((gridSize - 1) / 2f, gridSize + 1f, (gridSize - 1) / 2f);
    }

    /// <summary>
    /// Destroys the previous actors, rebuilds the environment, then spawns every
    /// actor and the agent on distinct, randomly chosen cells.
    /// </summary>
    /// <exception cref="System.InvalidOperationException">
    /// The grid has fewer cells than actors plus the agent.
    /// </exception>
    public override void AcademyReset()
    {
        foreach (GameObject actor in actorObjs)
        {
            DestroyImmediate(actor);
        }
        SetEnvironment();

        actorObjs.Clear();

        // One cell per actor plus one for the agent. Without this check the
        // sampling loop below could never collect enough distinct cells and
        // would spin forever.
        int cellCount = gridSize * gridSize;
        int requiredCells = players.Length + 1;
        if (cellCount < requiredCells)
        {
            throw new System.InvalidOperationException(
                "GridAcademy: a grid of size " + gridSize + " has " + cellCount +
                " cells, but " + requiredCells + " distinct cells are required.");
        }

        // Draw distinct cell numbers in [0, cellCount); the HashSet drops repeats.
        HashSet<int> numbers = new HashSet<int>();
        while (numbers.Count < requiredCells)
        {
            numbers.Add(Random.Range(0, cellCount));
        }
        int[] numbersA = numbers.ToArray();

        // Cell number -> (x, z): x = number / gridSize, z = number % gridSize.
        // E.g. on a 5x5 grid, cell 10 maps to x = 2, z = 0.
        for (int i = 0; i < players.Length; i++)
        {
            int x = (numbersA[i]) / gridSize;
            int y = (numbersA[i]) % gridSize;
            GameObject actorObj = Instantiate(objects[players[i]]);
            actorObj.transform.position = new Vector3(x, -0.25f, y);
            actorObjs.Add(actorObj);
        }

        // The last drawn cell is reserved for the agent.
        int x_a = (numbersA[players.Length]) / gridSize;
        int y_a = (numbersA[players.Length]) % gridSize;
        trueAgent.transform.position = new Vector3(x_a, -0.25f, y_a);
    }

    /// <summary>
    /// No per-step academy logic.
    /// </summary>
    public override void AcademyStep()
    {
    }
}
还有GridAgent.cs
using System;
using UnityEngine;
using System.Linq;
using MLAgents;

/// <summary>
/// Agent of the "GridWorld" example: moves one cell per action, is rewarded for
/// stepping onto a goal and punished for stepping onto a pit.
/// </summary>
public class GridAgent : Agent
{
    /// <summary>
    /// Academy found in the scene; supplies the grid size and the reset logic.
    /// </summary>
    [Header("Specific to GridWorld")]
    private GridAcademy academy;

    /// <summary>
    /// Minimum time between two decision requests while running inference.
    /// </summary>
    public float timeBetweenDecisionsAtInference;

    /// <summary>
    /// Time accumulated since the last decision request (inference only).
    /// </summary>
    private float timeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
             "a camera to render before making a decision. Place the agentCam here if using " +
             "RenderTexture as observations.")]
    public Camera renderCamera;

    /// <summary>
    /// When true, moves that would run into the outer wall are masked out.
    /// </summary>
    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
             "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    private const int NoAction = 0; // do nothing!
    private const int Up = 1;
    private const int Down = 2;
    private const int Left = 3;
    private const int Right = 4;

    /// <summary>
    /// Looks up the <c>GridAcademy</c> instance in the scene.
    /// </summary>
    public override void InitializeAgent()
    {
        academy = FindObjectOfType<GridAcademy>();
    }

    /// <summary>
    /// Adds no vector observations; only applies the action mask when enabled.
    /// </summary>
    public override void CollectObservations()
    {
        if (!maskActions)
        {
            return;
        }

        SetMask();
    }

    /// <summary>
    /// Masks every move that would leave the grid. Valid cell coordinates run
    /// from 0 to gridSize - 1 on both axes.
    /// </summary>
    private void SetMask()
    {
        int cellX = (int) transform.position.x;
        int cellZ = (int) transform.position.z;
        int lastCell = academy.gridSize - 1;

        // On the first column a step left would hit the wall; likewise below.
        if (cellX == 0)
        {
            SetActionMask(Left);
        }

        if (cellX == lastCell)
        {
            SetActionMask(Right);
        }

        if (cellZ == 0)
        {
            SetActionMask(Down);
        }

        if (cellZ == lastCell)
        {
            SetActionMask(Up);
        }
    }

    /// <summary>
    /// Applies one discrete move. The agent stays put when the target cell
    /// overlaps a wall; otherwise it moves and the goal/pit rewards are applied.
    /// </summary>
    /// <param name="vectorAction">Action vector; only element 0 is read.</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">The action is not one of the five known values.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Small penalty on every step.
        AddReward(-0.01f);

        int action = Mathf.FloorToInt(vectorAction[0]);
        Vector3 targetPos = ResolveTarget(action);

        Collider[] hits = Physics.OverlapBox(targetPos, new Vector3(0.3f, 0.3f, 0.3f));

        // This check is still needed when action masking is off; without it the
        // agent would walk straight through the wall.
        if (CountTagged(hits, "wall") != 0)
        {
            return;
        }

        transform.position = targetPos;

        if (CountTagged(hits, "goal") == 1)
        {
            Done();
            SetReward(1f);
        }

        if (CountTagged(hits, "pit") == 1)
        {
            Done();
            SetReward(-1f);
        }
    }

    /// <summary>
    /// Delegates the reset to the academy, which rebuilds the whole grid.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>
    /// Drives the decision timer on every physics tick.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Renders the observation camera (if any), then requests a decision: on
    /// every call while training, or once per
    /// <c>timeBetweenDecisionsAtInference</c> during inference.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        bool intervalElapsed = timeSinceDecision >= timeBetweenDecisionsAtInference;
        if (intervalElapsed)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }

    /// <summary>
    /// Returns the position the given action would move the agent to.
    /// </summary>
    private Vector3 ResolveTarget(int action)
    {
        Vector3 here = transform.position;
        switch (action)
        {
            case NoAction:
                return here;
            case Right:
                return here + Vector3.right;
            case Left:
                return here + Vector3.left;
            case Up:
                return here + Vector3.forward;
            case Down:
                return here + Vector3.back;
            default:
                throw new ArgumentException("Invalid action value");
        }
    }

    /// <summary>
    /// Counts the colliders whose game object carries the given tag.
    /// </summary>
    private static int CountTagged(Collider[] hits, string tag)
    {
        return hits.Count(hit => hit.gameObject.CompareTag(tag));
    }
}
暂时先这三个例子吧,其他的以后慢慢添加,其实主要也就是看代码,理解思路,以后自己要是做的时候可以参照着来,可能每个游戏都不一样,但是一些基本的东西应该是一样的。
好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,部分图片来自网络,侵删。