Source code for etna.transforms.feature_selection.base
import warnings
from abc import ABC
from typing import List
from typing import Union
import pandas as pd
from typing_extensions import Literal
from etna.transforms import Transform
class BaseFeatureSelectionTransform(Transform, ABC):
    """Base class for feature selection transforms.

    Keeps track of which feature columns are candidates for selection
    (``features_to_use``) and which of them were selected
    (``selected_features``).  :py:meth:`transform` drops every candidate
    that is not selected and leaves all non-candidate columns in place.
    """

    def __init__(self, features_to_use: Union[List[str], Literal["all"]] = "all"):
        """Init BaseFeatureSelectionTransform.

        Parameters
        ----------
        features_to_use:
            names of the features to perform the selection on, or ``"all"``
            to consider every feature column except ``"target"``
        """
        self.features_to_use = features_to_use
        # Starts empty and is never filled by this class; presumably a
        # subclass populates it while fitting -- verify against subclasses.
        self.selected_features: List[str] = []

    def _get_features_to_use(self, df: pd.DataFrame) -> List[str]:
        """Get list of features from the dataframe to perform the selection on.

        Parameters
        ----------
        df:
            dataframe whose columns contain a level named ``"feature"``

        Returns
        -------
        :
            sorted names of the candidate features that are present in ``df``;
            ``"target"`` is never part of the result

        Warns
        -----
        UserWarning
            if ``features_to_use`` names something that is not an available
            (non-target) feature of ``df``
        """
        available = set(df.columns.get_level_values("feature")) - {"target"}
        if self.features_to_use == "all":
            return sorted(available)

        requested = set(self.features_to_use)
        # Compare as sets: a name repeated in ``features_to_use`` must not be
        # reported as a column missing from the dataframe.
        if not requested <= available:
            warnings.warn("Columns from feature_to_use which are out of dataframe columns will be dropped!")
        return sorted(available & requested)

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep the selected features and every column that was not a candidate.

        Parameters
        ----------
        df:
            dataframe with all segments data

        Returns
        -------
        result: pd.DataFrame
            Dataframe with only selected features
        """
        all_features = set(df.columns.get_level_values("feature"))
        # Columns that were never candidates for selection (e.g. "target")
        # are always carried over unchanged.
        not_candidates = all_features - set(self._get_features_to_use(df))
        kept = sorted(self.selected_features + list(not_candidates))
        # Selecting by a list of labels already produces a new frame, so the
        # input does not need to be copied in full beforehand.
        return df.loc[:, pd.IndexSlice[:, kept]]