# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .euler import Euler
from .euler_heun import EulerHeun
from .heun import Heun
from .log_ode import LogODEMidpoint
from .midpoint import Midpoint
from .milstein import MilsteinIto, MilsteinStratonovich
from .reversible_heun import ReversibleHeun, AdjointReversibleHeun
from .srk import SRK
from ...settings import METHODS, SDE_TYPES


def select(method, sde_type):
    """Return the solver implementing ``method`` for the given ``sde_type``.

    The solver is returned as-is (it is not instantiated here).

    Args:
        method: Name of the numerical method, compared with ``==`` against
            the attributes of ``METHODS``.
        sde_type: SDE interpretation. It is only consulted when ``method`` is
            ``METHODS.milstein``, where it selects between the Itô and
            Stratonovich variants.

    Returns:
        The solver matching ``method`` (and ``sde_type`` for Milstein).

    Raises:
        ValueError: If ``method`` matches none of the known methods, or if
            ``method`` is Milstein and ``sde_type`` is neither
            ``SDE_TYPES.ito`` nor ``SDE_TYPES.stratonovich``.
    """
    # Milstein is the only method with separate Itô / Stratonovich
    # implementations, so it is the only branch that looks at `sde_type`.
    # Handling it explicitly gives an accurate error for an unsupported
    # SDE type instead of claiming the method itself is unknown.
    if method == METHODS.milstein:
        if sde_type == SDE_TYPES.ito:
            return MilsteinIto
        if sde_type == SDE_TYPES.stratonovich:
            return MilsteinStratonovich
        raise ValueError(
            f"Method '{method}' does not support SDE type '{sde_type}'."
        )

    # Matched with `==` in a linear scan (rather than a dict lookup) so the
    # comparison semantics are the same as the original if/elif chain,
    # including for unhashable `method` values.
    solvers = (
        (METHODS.euler, Euler),
        (METHODS.srk, SRK),
        (METHODS.midpoint, Midpoint),
        (METHODS.reversible_heun, ReversibleHeun),
        (METHODS.adjoint_reversible_heun, AdjointReversibleHeun),
        (METHODS.heun, Heun),
        (METHODS.log_ode_midpoint, LogODEMidpoint),
        (METHODS.euler_heun, EulerHeun),
    )
    for name, solver in solvers:
        if method == name:
            return solver
    raise ValueError(f"Method '{method}' does not match any known method.")