271 lines
8.5 KiB
Python
271 lines
8.5 KiB
Python
|
|
import re
|
||
|
|
import inspect
|
||
|
|
from .backwardscompat import callable
|
||
|
|
|
||
|
|
# metaclass implementation idea from
|
||
|
|
# http://blog.ianbicking.org/more-on-python-metaprogramming-comment-14.html
|
||
|
|
# Pending transition definitions: filled by transition() calls and consumed
# (then reset) by MetaStateMachine when the owning class object is created.
_transition_gatherer = list()


def transition(event, from_, to, action=None, guard=None):
    """Declare a transition inside a StateMachine class body.

    The definition is only queued here; it is attached to the class when the
    class object is created.  *from_* may be a single state name or a
    list/tuple of names.  *action* and *guard* may each be a callable, the
    name of a machine attribute, or a list/tuple of those.
    """
    pending = [event, from_, to, action, guard]
    _transition_gatherer.append(pending)
|
||
|
|
|
||
|
|
# Pending state definitions: filled by state() calls and consumed (then
# reset) by MetaStateMachine when the owning class object is created.
_state_gatherer = list()


def state(name, enter=None, exit=None):
    """Declare a state inside a StateMachine class body.

    *enter* and *exit* are optional actions (a callable, the name of a
    machine attribute, or a list/tuple of those) run when the machine enters
    or leaves the state.  The definition is queued and attached to the class
    when the class object is created.
    """
    definition = [name, enter, exit]
    _state_gatherer.append(definition)
|
||
|
|
|
||
|
|
|
||
|
|
class MetaStateMachine(type):
    """Metaclass that attaches the queued states/transitions to each new class.

    ``state()`` and ``transition()`` append to module-level gatherer lists;
    when a class using this metaclass is created, the pending definitions are
    registered on it and the gatherers are reset for the next class.
    """

    def __new__(cls, name, bases, dictionary):
        global _transition_gatherer, _state_gatherer
        # Take ownership of the pending definitions and reset the gatherers
        # before doing any work, so that a failure below (e.g. a transition
        # naming an unknown state) cannot leak stale definitions into the
        # next class that gets defined.
        pending_states, _state_gatherer = _state_gatherer, []
        pending_transitions, _transition_gatherer = _transition_gatherer, []

        Machine = super(MetaStateMachine, cls).__new__(cls, name, bases, dictionary)
        Machine._class_transitions = []
        Machine._class_states = {}
        # States first: transitions look their endpoints up by name among
        # the class states.
        for state_args in pending_states:
            Machine._add_class_state(*state_args)
        for transition_args in pending_transitions:
            Machine._add_class_transition(*transition_args)
        return Machine
|
||
|
|
|
||
|
|
|
||
|
|
# Root class built by calling the metaclass directly; every class deriving
# from it (StateMachine and user machines) is therefore created through
# MetaStateMachine.
StateMachineBase = MetaStateMachine(
    'StateMachineBase',
    (object,),
    dict(),
)
|
||
|
|
|
||
|
|
|
||
|
|
class StateMachine(StateMachineBase):
    """Base class for declarative finite state machines.

    Subclasses declare states with ``state(...)`` and transitions with
    ``transition(...)`` in their class body, and define ``initial_state``
    (a state name, or a callable returning one).  Every transition event is
    installed as a method; calling it moves the machine along the single
    matching transition.  States and transitions can also be added to one
    instance with ``add_state`` / ``add_transition``.
    """

    def __init__(self):
        self._bring_definitions_to_object_level()
        self._inject_into_parts()
        self._validate_machine_definitions()
        # initial_state may be given as a callable that computes the name.
        if callable(self.initial_state):
            self.initial_state = self.initial_state()
        # Raises InvalidConfiguration for an undeclared initial state instead
        # of failing later with an AttributeError on None.
        self._current_state_object = self._require_state(self.initial_state)
        self._current_state_object.run_enter(self)
        self._create_state_getters()

    def __new__(cls, *args, **kwargs):
        # Per-instance containers are created here so they exist before
        # __init__ runs; __init__ then copies the class-level definitions in.
        obj = super(StateMachine, cls).__new__(cls)
        obj._states = {}
        obj._transitions = []
        return obj

    def _bring_definitions_to_object_level(self):
        # Copy the class-level definitions into the instance containers so
        # add_state/add_transition on this instance don't touch the class.
        # The _State/_Transition objects themselves are shared, not copied.
        self._states.update(self.__class__._class_states)
        self._transitions.extend(self.__class__._class_transitions)

    def _inject_into_parts(self):
        # Give every state and transition a back-reference to this machine.
        # NOTE(review): class-level parts are shared by all instances, so this
        # attribute ends up pointing at the most recently constructed machine.
        for collection in [self._states.values(), self._transitions]:
            for component in collection:
                component.machine = self

    def _validate_machine_definitions(self):
        """Raise InvalidConfiguration unless the machine is minimally defined."""
        if len(self._states) < 2:
            raise InvalidConfiguration('There must be at least two states')
        if not getattr(self, 'initial_state', None):
            raise InvalidConfiguration('There must exist an initial state')

    @classmethod
    def _add_class_state(cls, name, enter, exit):
        cls._class_states[name] = _State(name, enter, exit)

    def add_state(self, name, enter=None, exit=None):
        """Add a state to this instance only and create its ``is_<name>`` query."""
        state = _State(name, enter, exit)
        setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))
        self._states[name] = state

    def _current_state_name(self):
        return self._current_state_object.name

    current_state = property(_current_state_name)

    def changing_state(self, from_, to):
        """Hook called with the old and new state names just before a state change.

        The default implementation does nothing; subclasses may override it.
        """
        pass

    def _new_state(self, state):
        self.changing_state(self._current_state_object.name, state.name)
        self._current_state_object = state

    def _state_objects(self):
        return list(self._states.values())

    def states(self):
        """Return the names of all states known to this machine."""
        return [s.name for s in self._state_objects()]

    @classmethod
    def _add_class_transition(cls, event, from_, to, action, guard):
        # Unknown state names raise KeyError here, at class-creation time.
        transition = _Transition(event, [cls._class_states[s] for s in _listize(from_)],
                                 cls._class_states[to], action, guard)
        cls._class_transitions.append(transition)
        setattr(cls, event, transition.event_method())

    def add_transition(self, event, from_, to, action=None, guard=None):
        """Add a transition to this instance only and install its event method.

        *from_* is a state name or a list/tuple of names; *to* is a state name.

        Raises:
            InvalidConfiguration: if *from_* or *to* names an unknown state.
        """
        transition = _Transition(event, [self._require_state(s) for s in _listize(from_)],
                                 self._require_state(to), action, guard)
        self._transitions.append(transition)
        setattr(self, event, transition.event_method().__get__(self, self.__class__))

    def _require_state(self, name):
        # Like _state_by_name, but an unknown name is a configuration error
        # instead of a silent None that would only blow up later.
        state = self._state_by_name(name)
        if state is None:
            raise InvalidConfiguration('Unknown state %r' % (name,))
        return state

    def _process_transitions(self, event_name, *args, **kwargs):
        """Fire *event_name*: pick the one applicable transition and run it.

        Extra arguments are forwarded to the transition's action(s).
        """
        transitions = self._transitions_by_name(event_name)
        if not transitions:
            # Reachable e.g. through an event method inherited from a parent
            # machine class whose transitions this class does not have.
            raise InvalidTransition('Unknown event %s' % (event_name,))
        transitions = self._ensure_from_validity(transitions)
        this_transition = self._check_guards(transitions)
        this_transition.run(self, *args, **kwargs)

    def _create_state_getters(self):
        for state in self._state_objects():
            setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))

    def _state_by_name(self, name):
        # _states is keyed by state name, so this is a direct lookup;
        # returns None for an unknown name.
        return self._states.get(name)

    def _transitions_by_name(self, name):
        return [transition for transition in self._transitions if transition.event == name]

    def _ensure_from_validity(self, transitions):
        # Keep only the transitions whose source states include the current one.
        valid_transitions = [transition for transition in transitions
                             if transition.is_valid_from(self._current_state_object)]
        if not valid_transitions:
            raise InvalidTransition("Cannot %s from %s" % (
                transitions[0].event, self.current_state))
        return valid_transitions

    def _check_guards(self, transitions):
        # Exactly one transition must pass its guard(s).
        allowed_transitions = [transition for transition in transitions
                               if transition.check_guard(self)]
        if not allowed_transitions:
            raise GuardNotSatisfied("Guard is not satisfied for this transition")
        elif len(allowed_transitions) > 1:
            raise ForkedTransition("More than one transition was allowed for this event")
        return allowed_transitions[0]
|
||
|
|
|
||
|
|
|
||
|
|
class _Transition(object):
    """A transition fired by *event* from any of the *from_* states to *to*."""

    def __init__(self, event, from_, to, action, guard):
        self.event = event
        # List of source _State objects, as built by the add-transition callers.
        self.from_ = from_
        self.to = to
        self.action = action
        self.guard = _Guard(guard)

    def event_method(self):
        """Build the function installed on the machine under the event's name."""
        def generated_event(machine, *args, **kwargs):
            # Extra arguments are forwarded to the transition's action(s).
            machine._process_transitions(self.event, *args, **kwargs)
        generated_event.__doc__ = 'event %s' % self.event
        generated_event.__name__ = self.event
        return generated_event

    def is_valid_from(self, from_):
        """Return True if the state object *from_* is a source of this transition."""
        return from_ in _listize(self.from_)

    def check_guard(self, machine):
        """Return a truthy value if this transition's guard(s) pass for *machine*."""
        return self.guard.check(machine)

    def run(self, machine, *args, **kwargs):
        """Exit the current state, switch to *to*, enter it, then run the action(s)."""
        machine._current_state_object.run_exit(machine)
        machine._new_state(self.to)
        self.to.run_enter(machine)
        _ActionRunner(machine).run(self.action, *args, **kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
class _Guard(object):
    """Condition(s) that must hold for a transition to be taken.

    *action* is None (no guard), or a single item / list / tuple of items,
    each either a callable taking the machine or the name of a machine
    attribute.
    """

    def __init__(self, action):
        self.action = action

    def check(self, machine):
        """Evaluate the guard items in order, stopping at the first falsy result.

        Returns True when there is no guard; otherwise returns the value of
        the last item evaluated, so a falsy result means the guard failed.
        """
        if self.action is None:
            return True
        outcome = True
        for condition in _listize(self.action):
            if not outcome:
                break
            outcome = self._evaluate(machine, condition)
        return outcome

    def _evaluate(self, machine, item):
        # Callables receive the machine; anything else is looked up as an
        # attribute of the machine and called if that attribute is callable.
        if callable(item):
            return item(machine)
        value = getattr(machine, item)
        return value() if callable(value) else value
|
||
|
|
|
||
|
|
|
||
|
|
class _State(object):
    """A named state with optional enter/exit actions."""

    def __init__(self, name, enter, exit):
        self.name, self.enter, self.exit = name, enter, exit

    def getter_name(self):
        """Return the name of the boolean query method generated for this state."""
        return 'is_%s' % self.name

    def getter_method(self):
        """Return a function telling whether a given machine is currently in this state."""
        owner = self

        def state_getter(machine):
            return machine.current_state == owner.name
        return state_getter

    def run_enter(self, machine):
        """Run this state's enter action(s) against *machine*."""
        self._run_hook(machine, self.enter)

    def run_exit(self, machine):
        """Run this state's exit action(s) against *machine*."""
        self._run_hook(machine, self.exit)

    @staticmethod
    def _run_hook(machine, hook):
        runner = _ActionRunner(machine)
        runner.run(hook)
|
||
|
|
|
||
|
|
|
||
|
|
class _ActionRunner(object):
    """Run the action(s) attached to a transition or to a state's enter/exit.

    An action is a callable, the name of a machine attribute, or a
    list/tuple of those.
    """

    def __init__(self, machine):
        self.machine = machine

    def run(self, action_param, *args, **kwargs):
        """Run every action in *action_param*, forwarding ``*args``/``**kwargs``.

        A falsy *action_param* (e.g. None) means there is nothing to run.
        """
        if not action_param:
            return
        for action_item in _listize(action_param):
            self._run_action(action_item, *args, **kwargs)

    def _run_action(self, action, *args, **kwargs):
        # Plain callables get the machine as their first argument; names are
        # looked up on the machine (typically yielding bound methods).
        if callable(action):
            self._try_to_run_with_args(action, self.machine, *args, **kwargs)
        else:
            self._try_to_run_with_args(getattr(self.machine, action), *args, **kwargs)

    def _try_to_run_with_args(self, action, *args, **kwargs):
        """Call *action* with the arguments, or with none if it cannot accept them.

        Whether the arguments fit is decided from the signature before the
        call, so a TypeError raised inside the action itself propagates
        instead of triggering a second, argument-less call.
        """
        try:
            signature = inspect.signature(action)
        except (AttributeError, TypeError, ValueError):
            # No introspection available (some builtins/C callables, or
            # Pythons without inspect.signature): fall back to a trial call.
            signature = None
        if signature is None:
            try:
                action(*args, **kwargs)
            except TypeError:
                action()
            return
        try:
            signature.bind(*args, **kwargs)
        except TypeError:
            # The action does not take these arguments: call it bare.
            action()
        else:
            action(*args, **kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
class StateMachineError(Exception):
    """Common base class for all errors raised by this state machine module."""


class InvalidConfiguration(StateMachineError):
    """Raised when a machine is defined incorrectly.

    Examples include fewer than two states or a missing initial state.
    """


class InvalidTransition(StateMachineError):
    """Raised when an event is fired from a state that has no transition for it."""


class GuardNotSatisfied(StateMachineError):
    """Raised when no candidate transition for an event passes its guard(s)."""


class ForkedTransition(StateMachineError):
    """Raised when more than one transition for an event passes its guard(s)."""
|
||
|
|
|
||
|
|
|
||
|
|
def _listize(value):
    """Return *value* as a sequence.

    A list or tuple is returned unchanged, including an empty one. Any other
    value is wrapped in a one-element list. This lets parameters accept
    either a single item or several items.
    """
    # An explicit check instead of the old ``cond and a or b`` idiom, which
    # turned an empty list/tuple into ``[[]]`` because it is falsy.
    if isinstance(value, (list, tuple)):
        return value
    return [value]
|
||
|
|
|