【Unity】より汎用的な有限ステートマシンを実装する

前回、有限ステートマシンの基本的な実装について書きました。

light11.hatenadiary.com

今回はこの基礎を踏まえて、より汎用的で使いやすいステートマシンを作ってみます。

動作確認環境：Unity 2018.3.1

クラス外から遷移の登録とトリガーの呼び出しをできるように

この記事の前提として、前回の記事の最後のソースコードを元にします。

light11.hatenadiary.com

まず、遷移情報の登録と遷移のためのトリガーの呼び出しをクラス外から行えるようにします。
ソースコードは下記の通りです。

using System.Collections.Generic;
using System.Linq;
using UnityEngine;

/// <summary>
/// States handled by the sample state machine.
/// </summary>
public enum StateType
{
    /// <summary>Idle state.</summary>
    Idle = 0,

    /// <summary>Run state.</summary>
    Run = 1,

    /// <summary>Jump state.</summary>
    Jump = 2,
}

/// <summary>
/// Base class for one state of the first (class-based) state machine.
/// Concrete states override the three lifecycle callbacks below.
/// </summary>
public abstract class State
{
    /// <summary>Called when this state becomes the current state.</summary>
    public abstract void OnEnter();

    /// <summary>Called on each update while this state is current.</summary>
    /// <param name="deltaTime">Elapsed time value supplied by the caller.</param>
    public abstract void OnUpdate(float deltaTime);

    /// <summary>Called when this state stops being the current state.</summary>
    public abstract void OnExit();
}
/// <summary>
/// Idle state: writes a log line to the Unity console for each lifecycle callback.
/// </summary>
public class IdleState : State
{
    public override void OnEnter() { Debug.Log("Enter : Idle"); }

    // deltaTime is not used; this state only logs.
    public override void OnUpdate(float deltaTime) { Debug.Log("Update : Idle"); }

    public override void OnExit() { Debug.Log("Exit : Idle"); }
}
/// <summary>
/// Run state: writes a log line to the Unity console for each lifecycle callback.
/// </summary>
public class RunState : State
{
    public override void OnEnter() { Debug.Log("Enter : Run"); }

    // deltaTime is not used; this state only logs.
    public override void OnUpdate(float deltaTime) { Debug.Log("Update : Run"); }

    public override void OnExit() { Debug.Log("Exit : Run"); }
}
/// <summary>
/// Jump state: writes a log line to the Unity console for each lifecycle callback.
/// </summary>
public class JumpState : State
{
    public override void OnEnter() { Debug.Log("Enter : Jump"); }

    // deltaTime is not used; this state only logs.
    public override void OnUpdate(float deltaTime) { Debug.Log("Update : Jump"); }

    public override void OnExit() { Debug.Log("Exit : Jump"); }
}

/// <summary>
/// One outgoing transition: when <see cref="Trigger"/> is fired, move to <see cref="To"/>.
/// The source state is implied by the per-state list the transition is stored in.
/// </summary>
public class Transition
{
    /// <summary>Destination state.</summary>
    public StateType To { get; set; }

    /// <summary>Trigger that activates this transition.</summary>
    public TriggerType Trigger { get; set; }
}

/// <summary>
/// Triggers that can be fired at the sample state machine (one per demo key).
/// </summary>
public enum TriggerType
{
    /// <summary>Fired by the demo when the I key is pressed.</summary>
    KeyDownI = 0,

    /// <summary>Fired by the demo when the J key is pressed.</summary>
    KeyDownJ = 1,

    /// <summary>Fired by the demo when the R key is pressed.</summary>
    KeyDownR = 2,
}

/// <summary>
/// Finite state machine over <see cref="StateType"/>. Transitions are registered from outside
/// with <see cref="AddTransition"/> and taken by calling <see cref="ExecuteTrigger"/>.
/// </summary>
public class StateMachine
{
    private StateType _stateType;
    private State _state;

    // State instance for each StateType.
    private Dictionary<StateType, State> _stateTypes = new Dictionary<StateType, State>();
    // Outgoing transitions, keyed by source state.
    private Dictionary<StateType, List<Transition>> _transitionLists = new Dictionary<StateType, List<Transition>>();

    /// <summary>
    /// Creates one State object per StateType and enters <paramref name="initialState"/>
    /// (its OnEnter is called from this constructor).
    /// </summary>
    /// <param name="initialState">State to start in.</param>
    public StateMachine(StateType initialState)
    {
        _stateTypes.Add(StateType.Idle, new IdleState());
        _stateTypes.Add(StateType.Run, new RunState());
        _stateTypes.Add(StateType.Jump, new JumpState());

        // Enter the first state.
        ChangeState(initialState);
    }

    /// <summary>
    /// Fires a trigger. If the current state has a transition registered for
    /// <paramref name="trigger"/>, the machine moves to that transition's destination;
    /// otherwise the trigger is ignored.
    /// </summary>
    /// <param name="trigger">Trigger to fire.</param>
    public void ExecuteTrigger(TriggerType trigger)
    {
        List<Transition> transitions;
        if (!_transitionLists.TryGetValue(_stateType, out transitions))
        {
            // No transitions were registered from the current state. Ignore the trigger
            // instead of throwing KeyNotFoundException.
            return;
        }

        foreach (var transition in transitions)
        {
            if (transition.Trigger == trigger)
            {
                ChangeState(transition.To);
                break;
            }
        }
    }

    /// <summary>
    /// Registers the transition taken from <paramref name="from"/> when <paramref name="trigger"/>
    /// is fired. Each (from, trigger) pair has at most one destination; registering the same pair
    /// again replaces its destination.
    /// </summary>
    /// <param name="from">Source state.</param>
    /// <param name="to">Destination state.</param>
    /// <param name="trigger">Trigger that activates the transition.</param>
    public void AddTransition(StateType from, StateType to, TriggerType trigger)
    {
        List<Transition> transitions;
        if (!_transitionLists.TryGetValue(from, out transitions))
        {
            transitions = new List<Transition>();
            _transitionLists.Add(from, transitions);
        }

        // Look up by trigger, not destination. ExecuteTrigger takes the first transition whose
        // trigger matches, so a second entry with the same trigger could never be reached.
        var transition = transitions.FirstOrDefault(x => x.Trigger == trigger);
        if (transition == null)
        {
            // New registration.
            transitions.Add(new Transition { To = to, Trigger = trigger });
        }
        else
        {
            // Re-registration: redirect the existing transition.
            transition.To = to;
        }
    }

    /// <summary>
    /// Forwards the update to the current state.
    /// </summary>
    /// <param name="deltaTime">Elapsed time passed to the current state's OnUpdate.</param>
    public void Update(float deltaTime)
    {
        _state.OnUpdate(deltaTime);
    }

    /// <summary>
    /// Switches state directly: calls OnExit on the current state (if any), then OnEnter on the new one.
    /// </summary>
    private void ChangeState(StateType stateType)
    {
        if (_state != null) {
            _state.OnExit();
        }

        _stateType = stateType;
        _state = _stateTypes[_stateType];
        _state.OnEnter();
    }
}

これをこんな感じで使います。

using UnityEngine;

/// <summary>
/// Demo component: wires the I/J/R keys to the state machine's triggers and updates it every frame.
/// </summary>
public class StateMachineDemo : MonoBehaviour
{
    private StateMachine _stateMachine;

    private void Start()
    {
        // Create the state machine; its constructor already enters Idle.
        _stateMachine = new StateMachine(StateType.Idle);

        // Idle --R--> Run, Idle --J--> Jump, Run --I--> Idle, Jump --I--> Idle
        _stateMachine.AddTransition(StateType.Idle, StateType.Run, TriggerType.KeyDownR);
        _stateMachine.AddTransition(StateType.Idle, StateType.Jump, TriggerType.KeyDownJ);
        _stateMachine.AddTransition(StateType.Run, StateType.Idle, TriggerType.KeyDownI);
        _stateMachine.AddTransition(StateType.Jump, StateType.Idle, TriggerType.KeyDownI);
    }

    private void Update()
    {
        // Keys are checked in the order I, J, R, before the machine is updated.
        FireOnKeyDown(KeyCode.I, TriggerType.KeyDownI);
        FireOnKeyDown(KeyCode.J, TriggerType.KeyDownJ);
        FireOnKeyDown(KeyCode.R, TriggerType.KeyDownR);

        // Let the current state run for this frame.
        _stateMachine.Update(Time.deltaTime);
    }

    // Fires the trigger when the key went down this frame.
    private void FireOnKeyDown(KeyCode key, TriggerType trigger)
    {
        if (Input.GetKeyDown(key))
        {
            _stateMachine.ExecuteTrigger(trigger);
        }
    }
}

少し使いやすくなりました。

状態も外から登録できるようにする

次に各Stateもクラス外から登録できるようにします。

using System.Collections.Generic;
using System.Linq;
using System;

// Holds the delegates registered for a single state.
public class StateMapping
{
    /// <summary>Invoked when the state is entered. May be null (nothing is called).</summary>
    public Action onEnter;

    /// <summary>Invoked when the state is exited. May be null (nothing is called).</summary>
    public Action onExit;

    /// <summary>
    /// Invoked on each update while the state is current, with the delta time.
    /// May be null (nothing is called).
    /// </summary>
    public Action<float> onUpdate;
}

/// <summary>
/// One outgoing transition: when <see cref="Trigger"/> is fired, move to <see cref="To"/>.
/// The source state is implied by the per-state list the transition is stored in.
/// </summary>
/// <typeparam name="TState">State identifier type.</typeparam>
/// <typeparam name="TTrigger">Trigger identifier type.</typeparam>
public class Transition<TState, TTrigger>
{
    /// <summary>Destination state.</summary>
    public TState To { get; set; }

    /// <summary>Trigger that activates this transition.</summary>
    public TTrigger Trigger { get; set; }
}

/// <summary>
/// Generic finite state machine. Per-state behaviour is supplied as delegates via
/// <see cref="SetupState"/>, and transitions via <see cref="AddTransition"/>.
/// <typeparamref name="TState"/> is expected to be an enum type: the constructor enumerates it
/// with Enum.GetValues, which throws ArgumentException for non-enum types.
/// </summary>
/// <typeparam name="TState">Enum of states.</typeparam>
/// <typeparam name="TTrigger">Trigger identifier type.</typeparam>
public class StateMachine<TState, TTrigger>
    where TState : struct, IConvertible, IComparable
    where TTrigger : struct, IConvertible, IComparable
{
    private TState _stateType;
    private StateMapping _stateMapping;

    // Delegates for every value of TState. Keyed by TState (not object) so lookups do not box.
    private Dictionary<TState, StateMapping> _stateMappings = new Dictionary<TState, StateMapping>();
    // Outgoing transitions, keyed by source state.
    private Dictionary<TState, List<Transition<TState, TTrigger>>> _transitionLists = new Dictionary<TState, List<Transition<TState, TTrigger>>>();

    /// <summary>
    /// Creates an empty StateMapping for every value of <typeparamref name="TState"/> and enters
    /// <paramref name="initialState"/>.
    /// </summary>
    /// <param name="initialState">State to start in.</param>
    /// <remarks>
    /// The initial state is entered inside the constructor, i.e. before <see cref="SetupState"/>
    /// can be called, so delegates registered afterwards are not invoked for that first entry.
    /// </remarks>
    public StateMachine(TState initialState)
    {
        // Build one StateMapping per enum value. The indexer (instead of Add) tolerates enums that
        // declare several names for the same underlying value.
        foreach (TState state in Enum.GetValues(typeof(TState)))
        {
            _stateMappings[state] = new StateMapping();
        }

        // Enter the first state.
        ChangeState(initialState);
    }

    /// <summary>
    /// Fires a trigger. If the current state has a transition registered for
    /// <paramref name="trigger"/>, the machine moves to its destination; otherwise the trigger is ignored.
    /// </summary>
    /// <param name="trigger">Trigger to fire.</param>
    public void ExecuteTrigger(TTrigger trigger)
    {
        List<Transition<TState, TTrigger>> transitions;
        if (!_transitionLists.TryGetValue(_stateType, out transitions))
        {
            // No transitions registered from the current state. Ignore the trigger
            // instead of throwing KeyNotFoundException.
            return;
        }

        var comparer = EqualityComparer<TTrigger>.Default;
        foreach (var transition in transitions)
        {
            if (comparer.Equals(transition.Trigger, trigger))
            {
                ChangeState(transition.To);
                break;
            }
        }
    }

    /// <summary>
    /// Registers the transition taken from <paramref name="from"/> when <paramref name="trigger"/>
    /// is fired. Each (from, trigger) pair has at most one destination; registering the same pair
    /// again replaces its destination.
    /// </summary>
    /// <param name="from">Source state.</param>
    /// <param name="to">Destination state.</param>
    /// <param name="trigger">Trigger that activates the transition.</param>
    public void AddTransition(TState from, TState to, TTrigger trigger)
    {
        List<Transition<TState, TTrigger>> transitions;
        if (!_transitionLists.TryGetValue(from, out transitions))
        {
            transitions = new List<Transition<TState, TTrigger>>();
            _transitionLists.Add(from, transitions);
        }

        // Look up by trigger, not destination. ExecuteTrigger takes the first transition whose
        // trigger matches, so a second entry with the same trigger could never be reached.
        var comparer = EqualityComparer<TTrigger>.Default;
        var transition = transitions.FirstOrDefault(x => comparer.Equals(x.Trigger, trigger));
        if (transition == null)
        {
            // New registration.
            transitions.Add(new Transition<TState, TTrigger> { To = to, Trigger = trigger });
        }
        else
        {
            // Re-registration: redirect the existing transition.
            transition.To = to;
        }
    }

    /// <summary>
    /// Sets the delegates for <paramref name="state"/>, replacing any previously set ones.
    /// Any delegate may be null, in which case nothing is called for that event.
    /// </summary>
    public void SetupState(TState state, Action onEnter, Action onExit, Action<float> onUpdate)
    {
        var stateMapping = _stateMappings[state];
        stateMapping.onEnter = onEnter;
        stateMapping.onExit = onExit;
        stateMapping.onUpdate = onUpdate;
    }

    /// <summary>
    /// Calls the current state's onUpdate delegate, if one is set.
    /// </summary>
    /// <param name="deltaTime">Value passed to onUpdate.</param>
    public void Update(float deltaTime)
    {
        if (_stateMapping != null && _stateMapping.onUpdate != null) {
            _stateMapping.onUpdate(deltaTime);
        }
    }

    /// <summary>
    /// Switches state directly: calls the current state's onExit (if set), then the new state's onEnter (if set).
    /// </summary>
    private void ChangeState(TState to)
    {
        // OnExit
        if (_stateMapping != null && _stateMapping.onExit != null) {
            _stateMapping.onExit();
        }

        // OnEnter
        _stateType = to;
        _stateMapping = _stateMappings[to];
        if (_stateMapping.onEnter != null) {
            _stateMapping.onEnter();
        }
    }
}

使う側はこんな感じです。

using UnityEngine;

/// <summary>
/// Demo component for the generic state machine: defines its own state/trigger enums,
/// registers transitions and per-state logging delegates, and maps the I/J/R keys to triggers.
/// </summary>
public class StateMachineDemo : MonoBehaviour {

    public enum StateType
    {
        Idle,
        Run,
        Jump,
    }

    public enum TriggerType
    {
        KeyDownI,
        KeyDownJ,
        KeyDownR,
    }

    private StateMachine<StateType, TriggerType> _stateMachine;

    private void Start () {
        // Create the state machine. Its constructor enters Idle right away, before the
        // SetupState calls below, so "OnEnter: Idle" is not logged for that first entry.
        _stateMachine = new StateMachine<StateType, TriggerType>(StateType.Idle);

        // Register transitions.
        _stateMachine.AddTransition(StateType.Idle, StateType.Run, TriggerType.KeyDownR);
        _stateMachine.AddTransition(StateType.Idle, StateType.Jump, TriggerType.KeyDownJ);
        _stateMachine.AddTransition(StateType.Run, StateType.Idle, TriggerType.KeyDownI);
        _stateMachine.AddTransition(StateType.Jump, StateType.Idle, TriggerType.KeyDownI);

        // Register per-state behaviour (enter, exit, update). The exit messages now use the
        // same "Name: State" format as the enter/update messages.
        _stateMachine.SetupState(StateType.Idle, () => Debug.Log("OnEnter: Idle"), () => Debug.Log("OnExit: Idle"), deltaTime => Debug.Log("OnUpdate: Idle"));
        _stateMachine.SetupState(StateType.Run, () => Debug.Log("OnEnter: Run"), () => Debug.Log("OnExit: Run"), deltaTime => Debug.Log("OnUpdate: Run"));
        _stateMachine.SetupState(StateType.Jump, () => Debug.Log("OnEnter: Jump"), () => Debug.Log("OnExit: Jump"), deltaTime => Debug.Log("OnUpdate: Jump"));
    }

    private void Update()
    {
        // Fire triggers for keys pressed this frame.
        if (Input.GetKeyDown(KeyCode.I)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownI);
        if (Input.GetKeyDown(KeyCode.J)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownJ);
        if (Input.GetKeyDown(KeyCode.R)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownR);

        // Update the state machine.
        _stateMachine.Update(Time.deltaTime);
    }
}

これでStateMachineとしての機能だけを分離できました。
使い方次第でいろんな用途に使える汎用的なStateMachineになりました。

コルーチンに対応

最後にコルーチンも使えるようにしてみます。

using System.Collections.Generic;
using System.Collections;
using System.Linq;
using System;
using UnityEngine;

// Holds the delegates (including coroutine factories) registered for a single state.
public class StateMapping
{
    /// <summary>Invoked when the state is entered. May be null.</summary>
    public Action onEnter;

    /// <summary>Creates the coroutine run before onEnter when the state is entered. May be null.</summary>
    public Func<IEnumerator> EnterRoutine;

    /// <summary>Invoked when the state is exited. May be null.</summary>
    public Action onExit;

    /// <summary>Creates the coroutine run before onExit when the state is exited. May be null.</summary>
    public Func<IEnumerator> ExitRoutine;

    /// <summary>Invoked on each update while the state is current, with the delta time. May be null.</summary>
    public Action<float> onUpdate;
}

/// <summary>
/// One outgoing transition: when <see cref="Trigger"/> is fired, move to <see cref="To"/>.
/// The source state is implied by the per-state list the transition is stored in.
/// </summary>
/// <typeparam name="TState">State identifier type.</typeparam>
/// <typeparam name="TTrigger">Trigger identifier type.</typeparam>
public class Transition<TState, TTrigger>
{
    /// <summary>Destination state.</summary>
    public TState To { get; set; }

    /// <summary>Trigger that activates this transition.</summary>
    public TTrigger Trigger { get; set; }
}

/// <summary>
/// Generic finite state machine whose enter/exit steps can include Unity coroutines.
/// <typeparamref name="TState"/> is expected to be an enum type (the constructor enumerates it
/// with Enum.GetValues). Transitions started by <see cref="ExecuteTrigger"/> run as coroutines
/// on the MonoBehaviour passed to the constructor.
/// </summary>
/// <typeparam name="TState">Enum of states.</typeparam>
/// <typeparam name="TTrigger">Trigger identifier type.</typeparam>
public class StateMachine<TState, TTrigger>
    where TState : struct, IConvertible, IComparable
    where TTrigger : struct, IConvertible, IComparable
{
    // Host on which transition coroutines are started.
    private MonoBehaviour _monoBehaviour;
    // Current state. During a coroutine transition this still holds the source state until the
    // enter step has completed.
    private TState _stateType;
    private StateMapping _stateMapping;
    // Destination of the transition in progress; null when no transition is running.
    private TState? _destinationState;
    // True while the exit step (ExitRoutine, then onExit) is running.
    private bool _inExitTransition;
    // True while the enter step (EnterRoutine, then onEnter) is running.
    private bool _inEnterTransition;

    // Delegates for every value of TState. Keyed by TState (not object) so lookups do not box.
    private Dictionary<TState, StateMapping> _stateMappings = new Dictionary<TState, StateMapping>();
    // Outgoing transitions, keyed by source state.
    private Dictionary<TState, List<Transition<TState, TTrigger>>> _transitionLists = new Dictionary<TState, List<Transition<TState, TTrigger>>>();

    /// <summary>
    /// Creates an empty StateMapping for every value of <typeparamref name="TState"/> and
    /// immediately enters <paramref name="initialState"/>.
    /// </summary>
    /// <param name="monoBehaviour">MonoBehaviour on which transition coroutines are started.</param>
    /// <param name="initialState">State to start in.</param>
    /// <remarks>
    /// The initial state is entered without coroutines (EnterRoutine is not run), and before
    /// <see cref="SetupState"/> can be called, so delegates registered afterwards are not invoked
    /// for that first entry.
    /// </remarks>
    public StateMachine(MonoBehaviour monoBehaviour, TState initialState)
    {
        _monoBehaviour = monoBehaviour;

        // Build one StateMapping per enum value. The indexer (instead of Add) tolerates enums that
        // declare several names for the same underlying value.
        foreach (TState state in Enum.GetValues(typeof(TState)))
        {
            _stateMappings[state] = new StateMapping();
        }

        // Enter the first state.
        ChangeStateImmediately(initialState);
    }

    /// <summary>
    /// Fires a trigger. If the current state has a transition registered for
    /// <paramref name="trigger"/>, a coroutine transition to its destination is started on the
    /// host MonoBehaviour; otherwise the trigger is ignored.
    /// </summary>
    /// <param name="trigger">Trigger to fire.</param>
    /// <remarks>
    /// Triggers fired while an enter step is running are dropped. A trigger fired while an exit
    /// step is running retargets the in-progress transition to the new destination.
    /// </remarks>
    public void ExecuteTrigger(TTrigger trigger)
    {
        if (_inEnterTransition)
        {
            // ChangeState would reject the request at once; don't start a coroutine just for that.
            return;
        }

        List<Transition<TState, TTrigger>> transitions;
        if (!_transitionLists.TryGetValue(_stateType, out transitions))
        {
            // No transitions registered from the current state. Ignore the trigger
            // instead of throwing KeyNotFoundException.
            return;
        }

        var comparer = EqualityComparer<TTrigger>.Default;
        foreach (var transition in transitions)
        {
            if (comparer.Equals(transition.Trigger, trigger))
            {
                _monoBehaviour.StartCoroutine(ChangeState(transition.To));
                break;
            }
        }
    }

    /// <summary>
    /// Registers the transition taken from <paramref name="from"/> when <paramref name="trigger"/>
    /// is fired. Each (from, trigger) pair has at most one destination; registering the same pair
    /// again replaces its destination.
    /// </summary>
    /// <param name="from">Source state.</param>
    /// <param name="to">Destination state.</param>
    /// <param name="trigger">Trigger that activates the transition.</param>
    public void AddTransition(TState from, TState to, TTrigger trigger)
    {
        List<Transition<TState, TTrigger>> transitions;
        if (!_transitionLists.TryGetValue(from, out transitions))
        {
            transitions = new List<Transition<TState, TTrigger>>();
            _transitionLists.Add(from, transitions);
        }

        // Look up by trigger, not destination. ExecuteTrigger takes the first transition whose
        // trigger matches, so a second entry with the same trigger could never be reached.
        var comparer = EqualityComparer<TTrigger>.Default;
        var transition = transitions.FirstOrDefault(x => comparer.Equals(x.Trigger, trigger));
        if (transition == null)
        {
            // New registration.
            transitions.Add(new Transition<TState, TTrigger> { To = to, Trigger = trigger });
        }
        else
        {
            // Re-registration: redirect the existing transition.
            transition.To = to;
        }
    }

    /// <summary>
    /// Sets the delegates for <paramref name="state"/>, replacing any previously set ones.
    /// Omitted (null) delegates are simply not called.
    /// </summary>
    public void SetupState(TState state, Action onEnter = null, Func<IEnumerator> enterRoutine = null, Action onExit = null, Func<IEnumerator> exitRoutine = null, Action<float> onUpdate = null)
    {
        var stateMapping = _stateMappings[state];
        stateMapping.onEnter = onEnter;
        stateMapping.EnterRoutine = enterRoutine;
        stateMapping.onExit = onExit;
        stateMapping.ExitRoutine = exitRoutine;
        stateMapping.onUpdate = onUpdate;
    }

    /// <summary>
    /// Calls the current state's onUpdate delegate, if set. Does nothing while a transition is in progress.
    /// </summary>
    /// <param name="deltaTime">Value passed to onUpdate.</param>
    public void Update(float deltaTime)
    {
        if (_inExitTransition || _inEnterTransition) {
            // Don't update while transitioning.
            return;
        }
        if (_stateMapping != null && _stateMapping.onUpdate != null) {
            _stateMapping.onUpdate(deltaTime);
        }
    }

    /// <summary>
    /// Switches state synchronously: calls onExit / onEnter only (no coroutines).
    /// </summary>
    private void ChangeStateImmediately(TState to)
    {
        // Exit
        if (_stateMapping != null) {
            if (_stateMapping.onExit != null) {
                _stateMapping.onExit();
            }
        }

        // Enter
        _stateType = to;
        _stateMapping = _stateMappings[_stateType];
        if (_stateMapping.onEnter != null) {
            _stateMapping.onEnter();
        }
    }

    /// <summary>
    /// Coroutine that moves to <paramref name="to"/>. It runs the current state's ExitRoutine and
    /// onExit, then the destination's EnterRoutine and onEnter. Only after that does the
    /// destination become the current state.
    /// </summary>
    private IEnumerator ChangeState(TState to)
    {
        if (_inEnterTransition) {
            // Already entering a state: drop this request (the transition fails).
            yield break;
        }

        _destinationState = to;
        if (_inExitTransition) {
            // An exit step is already running: retarget it and let that coroutine finish the transition.
            yield break;
        }

        // Exit
        _inExitTransition = true;
        if (_stateMapping != null) {
            if (_stateMapping.ExitRoutine != null) {
                yield return _monoBehaviour.StartCoroutine(_stateMapping.ExitRoutine());
            }
            if (_stateMapping.onExit != null) {
                _stateMapping.onExit();
            }
        }
        _inExitTransition = false;

        // Enter. _destinationState is read only now because it may have been retargeted during the exit step.
        _inEnterTransition = true;
        var stateMapping = _stateMappings[_destinationState.Value];
        if (stateMapping.EnterRoutine != null) {
            yield return stateMapping.EnterRoutine();
        }
        if (stateMapping.onEnter != null) {
            stateMapping.onEnter();
        }
        _inEnterTransition = false;

        // Commit the new current state.
        _stateType = _destinationState.Value;
        _stateMapping = _stateMappings[_stateType];

        _destinationState = null;
    }
}

使う側はこんな感じです。

using System.Collections.Generic;
using System.Collections;
using UnityEngine;

/// <summary>
/// Demo component for the coroutine-capable state machine: each state logs in its enter/exit
/// callbacks and waits one second in its enter/exit routines.
/// </summary>
public class StateMachineDemo : MonoBehaviour {

    public enum StateType
    {
        Idle,
        Run,
        Jump,
    }

    public enum TriggerType
    {
        KeyDownI,
        KeyDownJ,
        KeyDownR,
    }

    private StateMachine<StateType, TriggerType> _stateMachine;

    private void Start () {
        // Create the state machine. Its constructor enters Idle immediately, before the SetupState
        // calls below, so Idle's onEnter and EnterRoutine do not run for that first entry.
        _stateMachine = new StateMachine<StateType, TriggerType>(this, StateType.Idle);

        // Register transitions.
        _stateMachine.AddTransition(StateType.Idle, StateType.Run, TriggerType.KeyDownR);
        _stateMachine.AddTransition(StateType.Idle, StateType.Jump, TriggerType.KeyDownJ);
        _stateMachine.AddTransition(StateType.Run, StateType.Idle, TriggerType.KeyDownI);
        _stateMachine.AddTransition(StateType.Jump, StateType.Idle, TriggerType.KeyDownI);

        // Register per-state behaviour: onEnter, enterRoutine, onExit, exitRoutine.
        _stateMachine.SetupState(StateType.Idle, () => Debug.Log("OnEnter: Idle"), () => EnterRoutine(StateType.Idle), () => Debug.Log("OnExit: Idle"), () => ExitRoutine(StateType.Idle));
        _stateMachine.SetupState(StateType.Run, () => Debug.Log("OnEnter: Run"), () => EnterRoutine(StateType.Run), () => Debug.Log("OnExit: Run"), () => ExitRoutine(StateType.Run));
        _stateMachine.SetupState(StateType.Jump, () => Debug.Log("OnEnter: Jump"), () => EnterRoutine(StateType.Jump), () => Debug.Log("OnExit: Jump"), () => ExitRoutine(StateType.Jump));
    }

    private void Update()
    {
        // Fire triggers for keys pressed this frame.
        if (Input.GetKeyDown(KeyCode.I)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownI);
        if (Input.GetKeyDown(KeyCode.J)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownJ);
        if (Input.GetKeyDown(KeyCode.R)) _stateMachine.ExecuteTrigger(TriggerType.KeyDownR);

        // Update the state machine.
        _stateMachine.Update(Time.deltaTime);
    }

    // Enter routine: logs, waits one second, logs again.
    private IEnumerator EnterRoutine(StateType stateType)
    {
        Debug.Log(stateType + " : Enter routine start.");
        yield return new WaitForSeconds(1);
        Debug.Log(stateType + " : Enter routine end.");
    }

    // Exit routine: logs, waits one second, logs again. The messages previously said
    // "Enter routine", which made exit and enter indistinguishable in the console.
    private IEnumerator ExitRoutine(StateType stateType)
    {
        Debug.Log(stateType + " : Exit routine start.");
        yield return new WaitForSeconds(1);
        Debug.Log(stateType + " : Exit routine end.");
    }
}

Stateの開始時・終了時に、演出の完了を待つなど複数フレームにまたがる処理を挟みたい場合に便利になりました（なお、コルーチンはメインスレッド上で動くため、重い計算処理そのものを軽くするものではありません）。

参考

GitHub - thefuntastic/Unity3d-Finite-State-Machine: An intuitive Unity3d finite state machine (FSM). Designed with an emphasis on usability, without sacrificing utility.