Source code for etna.ensembles.base

from typing import List

from etna.datasets import TSDataset
from etna.loggers import tslogger
from etna.pipeline.base import BasePipeline


class EnsembleMixin:
    """Base mixin for the ensembles.

    Groups static helpers that operate on the pipelines of an ensemble:
    validating how many pipelines were given, resolving their common
    horizon, and fitting / forecasting a single pipeline with logging
    around the call.
    """

    @staticmethod
    def _validate_pipeline_number(pipelines: List[BasePipeline]) -> None:
        """Check that given valid number of pipelines.

        Parameters
        ----------
        pipelines:
            Pipelines that are going to form the ensemble.

        Raises
        ------
        ValueError:
            If fewer than two pipelines are given.
        """
        if len(pipelines) < 2:
            raise ValueError("At least two pipelines are expected.")

    @staticmethod
    def _get_horizon(pipelines: List[BasePipeline]) -> int:
        """Get ensemble's horizon.

        Parameters
        ----------
        pipelines:
            Pipelines whose ``horizon`` attributes are compared.

        Returns
        -------
        :
            The single horizon value shared by all the pipelines.

        Raises
        ------
        ValueError:
            If ``pipelines`` is empty, or if the pipelines do not all have
            the same horizon.
        """
        horizons = {pipeline.horizon for pipeline in pipelines}
        # Without this guard an empty input would reach ``set.pop`` below and
        # fail with an uninformative ``KeyError``.
        if not horizons:
            raise ValueError("At least one pipeline is expected to determine the horizon.")
        if len(horizons) > 1:
            raise ValueError("All the pipelines should have the same horizon.")
        # Exactly one distinct value is left at this point.
        return horizons.pop()

    @staticmethod
    def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline:
        """Fit given pipeline with ``ts``.

        Parameters
        ----------
        pipeline:
            Pipeline to fit.
        ts:
            Dataset passed to ``pipeline.fit``.

        Returns
        -------
        :
            The same ``pipeline`` object that was passed in, after
            ``pipeline.fit(ts=ts)`` has been called on it.
        """
        tslogger.log(msg=f"Start fitting {pipeline}.")
        pipeline.fit(ts=ts)
        tslogger.log(msg=f"Pipeline {pipeline} is fitted.")
        return pipeline

    @staticmethod
    def _forecast_pipeline(pipeline: BasePipeline) -> TSDataset:
        """Make forecast with given pipeline.

        Parameters
        ----------
        pipeline:
            Pipeline whose ``forecast`` method is called without arguments.

        Returns
        -------
        :
            Whatever ``pipeline.forecast()`` returns.
        """
        tslogger.log(msg=f"Start forecasting with {pipeline}.")
        forecast = pipeline.forecast()
        tslogger.log(msg=f"Forecast is done with {pipeline}.")
        return forecast