Source code for etna.transforms.feature_selection.gale_shapley

import warnings
from math import ceil
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import pandas as pd
from typing_extensions import Literal

from etna.analysis import RelevanceTable
from etna.core import BaseMixin
from etna.transforms.feature_selection.base import BaseFeatureSelectionTransform


class BaseGaleShapley(BaseMixin):
    """Base class for a participant (segment or feature) of Gale-Shapley matching."""

    def __init__(self, name: str, ranked_candidates: List[str]):
        """Init BaseGaleShapley.

        Parameters
        ----------
        name:
            name of object
        ranked_candidates:
            list of preferences for the object ranked descending by importance
        """
        self.name = name
        self.ranked_candidate = ranked_candidates
        # candidate -> position in the preference list (0 is the most preferred one)
        self.candidates_rank = dict(zip(ranked_candidates, range(len(ranked_candidates))))
        # current tentative partner and its position in the preference list; both None while unmatched
        self.tmp_match: Optional[str] = None
        self.tmp_match_rank: Optional[int] = None
        self.is_available = True

    def update_tmp_match(self, name: str):
        """Set ``name`` as the current tentative partner.

        Parameters
        ----------
        name:
            name of candidate to match
        """
        self.tmp_match = name
        self.tmp_match_rank = self.candidates_rank[name]
        self.is_available = False

    def reset_tmp_match(self):
        """Drop the current tentative partner and become available again."""
        self.tmp_match = self.tmp_match_rank = None
        self.is_available = True
class SegmentGaleShapley(BaseGaleShapley):
    """Segment participant of Gale-Shapley matching: proposes to features in preference order."""

    def __init__(self, name: str, ranked_candidates: List[str]):
        """Init SegmentGaleShapley.

        Parameters
        ----------
        name:
            name of segment
        ranked_candidates:
            list of features sorted descending by importance
        """
        super().__init__(name=name, ranked_candidates=ranked_candidates)
        # index in ``ranked_candidate`` of the last proposed feature; None before the first proposal
        self.last_candidate: Optional[int] = None

    def update_tmp_match(self, name: str):
        """Create match with given feature.

        Parameters
        ----------
        name:
            name of feature to match
        """
        super().update_tmp_match(name=name)
        # the proposal pointer follows the matched feature, so later proposals continue after it
        self.last_candidate = self.tmp_match_rank

    def get_next_candidate(self) -> Optional[str]:
        """Advance to the next feature in the preference list and return it.

        Returns
        -------
        name: str
            name of feature, or None once the preference list is exhausted
        """
        position = 0 if self.last_candidate is None else self.last_candidate + 1
        self.last_candidate = position
        if position < len(self.ranked_candidate):
            return self.ranked_candidate[position]
        return None
class FeatureGaleShapley(BaseGaleShapley):
    """Feature participant of Gale-Shapley matching: accepts or rejects segment proposals."""

    def check_segment(self, segment: str) -> bool:
        """Tell whether ``segment`` is preferred over the current tentative partner.

        Parameters
        ----------
        segment:
            segment to check

        Returns
        -------
        is_better: bool
            True if there is no current partner, or if ``segment`` ranks higher
            (has a smaller position) in this feature's preference list.
        """
        current_rank = self.tmp_match_rank
        if self.tmp_match is not None and current_rank is not None:
            return self.candidates_rank[segment] < current_rank
        return True
class GaleShapleyMatcher(BaseMixin):
    """Class for handling Gale-Shapley matching algo.

    Segments propose to features following their preference lists; every feature keeps the most
    preferred segment among those that proposed to it so far and rejects the others.
    """

    def __init__(self, segments: List[SegmentGaleShapley], features: List[FeatureGaleShapley]):
        """Init GaleShapleyMatcher.

        Parameters
        ----------
        segments:
            list of segments to build matches
        features:
            list of features to build matches
        """
        self.segments = segments
        self.features = features
        self.segment_by_name = {segment.name: segment for segment in self.segments}
        self.feature_by_name = {feature.name: feature for feature in self.features}

    @staticmethod
    def match(segment: SegmentGaleShapley, feature: FeatureGaleShapley):
        """Build match between segment and feature.

        Parameters
        ----------
        segment:
            segment to match
        feature:
            feature to match
        """
        segment.update_tmp_match(name=feature.name)
        feature.update_tmp_match(name=segment.name)

    @staticmethod
    def break_match(segment: SegmentGaleShapley, feature: FeatureGaleShapley):
        """Break match between segment and feature.

        Parameters
        ----------
        segment:
            segment to break match
        feature:
            feature to break match
        """
        segment.reset_tmp_match()
        feature.reset_tmp_match()

    def _gale_shapley_iteration(self, available_segments: List[SegmentGaleShapley]) -> bool:
        """Run iteration of Gale-Shapley matching for given available_segments.

        Parameters
        ----------
        available_segments:
            list of segments that have no match at this iteration

        Returns
        -------
        success: bool
            True if there is at least one match attempt at the iteration

        Notes
        -----
        Success code is necessary because in ETNA usage we can not guarantee that number of features will be
        big enough to build matches with all the segments. In case ``n_features < n_segments`` some segments always
        stay available that can cause infinite while loop in ``__call__``.
        """
        success = False
        for segment in available_segments:
            feature_name = segment.get_next_candidate()
            if feature_name is None:
                # this segment has already proposed to every feature of its preference list
                continue
            feature = self.feature_by_name[feature_name]
            success = True
            if not feature.check_segment(segment=segment.name):
                # the feature prefers its current partner: the proposal is rejected
                continue
            # a feature is matched (i.e. not available) exactly when its tmp_match is not None
            current_partner_name = feature.tmp_match
            if current_partner_name is not None:
                # the feature drops its current partner, which becomes available again
                self.break_match(segment=self.segment_by_name[current_partner_name], feature=feature)
            self.match(segment=segment, feature=feature)
        return success

    def _get_available_segments(self) -> List[SegmentGaleShapley]:
        """Get list of available segments."""
        return [segment for segment in self.segments if segment.is_available]

    def __call__(self) -> Dict[str, str]:
        """Run matching.

        Iterations continue while some segment is unmatched and at least one proposal was made
        in the previous iteration.

        Returns
        -------
        matching: Dict[str, str]
            matching dict of segment x feature; segments left without a match are omitted
        """
        success_run = True
        available_segments = self._get_available_segments()
        while available_segments and success_run:
            success_run = self._gale_shapley_iteration(available_segments=available_segments)
            available_segments = self._get_available_segments()
        return {segment.name: segment.tmp_match for segment in self.segments if segment.tmp_match is not None}
class GaleShapleyFeatureSelectionTransform(BaseFeatureSelectionTransform):
    """GaleShapleyFeatureSelectionTransform provides feature filtering with Gale-Shapley matching algo according to relevance table.

    Notes
    -----
    Transform works with any type of features, however most of the models work only with regressors.
    Therefore, it is recommended to pass the regressors into the feature selection transforms.
    """

    def __init__(
        self,
        relevance_table: RelevanceTable,
        top_k: int,
        features_to_use: Union[List[str], Literal["all"]] = "all",
        use_rank: bool = False,
        **relevance_params,
    ):
        """Init GaleShapleyFeatureSelectionTransform.

        Parameters
        ----------
        relevance_table:
            class to build relevance table
        top_k:
            number of features that should be selected from all the given ones
        features_to_use:
            columns of the dataset to select from
            if "all" value is given, all columns are used
        use_rank:
            if True, use rank in relevance table computation
        relevance_params:
            extra keyword arguments forwarded to ``relevance_table`` when it is called
        """
        super().__init__(features_to_use=features_to_use)
        self.relevance_table = relevance_table
        self.top_k = top_k
        self.use_rank = use_rank
        # With ranks a smaller value means a more relevant feature, whatever the underlying relevance measure.
        self.greater_is_better = False if use_rank else relevance_table.greater_is_better
        self.relevance_params = relevance_params

    def _compute_relevance_table(self, df: pd.DataFrame, features: List[str]) -> pd.DataFrame:
        """Compute relevance table with given data.

        The result is read elsewhere in this class with segments as rows and features as columns.
        """
        targets_df = df.loc[:, pd.IndexSlice[:, "target"]]
        features_df = df.loc[:, pd.IndexSlice[:, features]]
        table = self.relevance_table(
            df=targets_df, df_exog=features_df, return_ranks=self.use_rank, **self.relevance_params
        )
        return table

    @staticmethod
    def _get_ranked_list(table: pd.DataFrame, ascending: bool) -> Dict[str, List[str]]:
        """Get ranked lists of candidates from table of relevance.

        For every row of ``table`` the column names are sorted by the row values in the given direction.
        """
        ranked_features = {key: list(table.loc[key].sort_values(ascending=ascending).index) for key in table.index}
        return ranked_features

    @staticmethod
    def _compute_gale_shapley_steps_number(top_k: int, n_segments: int, n_features: int) -> int:
        """Get number of necessary Gale-Shapley algo iterations.

        Each iteration selects at most one feature per segment, hence ``ceil(top_k / n_segments)``
        iterations in the regular case; degenerate cases warn and use a single iteration.
        """
        if n_features < top_k:
            warnings.warn(
                f"Given top_k={top_k} is bigger than n_features={n_features}. "
                f"Transform will not filter features."
            )
            return 1
        if top_k < n_segments:
            warnings.warn(
                f"Given top_k={top_k} is less than n_segments. Algo will filter data without Gale-Shapley run."
            )
            return 1
        return ceil(top_k / n_segments)

    @staticmethod
    def _gale_shapley_iteration(
        segment_features_ranking: Dict[str, List[str]],
        feature_segments_ranking: Dict[str, List[str]],
    ) -> Dict[str, str]:
        """Build matching for all the segments.

        Parameters
        ----------
        segment_features_ranking:
            dict of relevance segment x sorted features
        feature_segments_ranking:
            dict of relevance feature x sorted segments

        Returns
        -------
        matching dict: Dict[str, str]
            dict of segment x feature
        """
        gssegments = [
            SegmentGaleShapley(
                name=name,
                ranked_candidates=ranked_candidates,
            )
            for name, ranked_candidates in segment_features_ranking.items()
        ]
        gsfeatures = [
            FeatureGaleShapley(name=name, ranked_candidates=ranked_candidates)
            for name, ranked_candidates in feature_segments_ranking.items()
        ]
        matcher = GaleShapleyMatcher(segments=gssegments, features=gsfeatures)
        new_matches = matcher()
        return new_matches

    @staticmethod
    def _update_ranking_list(
        segment_features_ranking: Dict[str, List[str]], features_to_drop: List[str]
    ) -> Dict[str, List[str]]:
        """Delete chosen features from candidates ranked lists.

        The lists are modified in place; the same dict is returned for convenience.
        """
        for segment in segment_features_ranking:
            for feature in features_to_drop:
                segment_features_ranking[segment].remove(feature)
        return segment_features_ranking

    @staticmethod
    def _process_last_step(
        matches: Dict[str, str], relevance_table: pd.DataFrame, n: int, greater_is_better: bool
    ) -> List[str]:
        """Choose n features from the matched ones, ordered by their relevance to the matched segment."""
        features_relevance = {feature: relevance_table[feature][segment] for segment, feature in matches.items()}
        sorted_features = sorted(features_relevance.items(), key=lambda item: item[1], reverse=greater_is_better)
        selected_features = [feature[0] for feature in sorted_features][:n]
        return selected_features

    def fit(self, df: pd.DataFrame) -> "GaleShapleyFeatureSelectionTransform":
        """Fit Gale-Shapley algo and find a pool of ``top_k`` features.

        Every Gale-Shapley run matches each segment with at most one still-unselected feature.
        All runs except the last select every matched feature; the last one selects the most
        relevant matched features needed to reach ``top_k``. If there are fewer features than
        ``top_k``, all features are kept.

        Parameters
        ----------
        df:
            dataframe to fit algo
        """
        features = self._get_features_to_use(df=df)
        relevance_table = self._compute_relevance_table(df=df, features=features)
        # Use self.greater_is_better: with use_rank=True the table holds ranks, where smaller is better
        # even if the relevance table itself is "greater is better".
        ascending = not self.greater_is_better
        segment_features_ranking = self._get_ranked_list(table=relevance_table, ascending=ascending)
        feature_segments_ranking = self._get_ranked_list(table=relevance_table.T, ascending=ascending)
        n_segments = len(segment_features_ranking)
        n_features = len(feature_segments_ranking)
        gale_shapley_steps_number = self._compute_gale_shapley_steps_number(
            top_k=self.top_k,
            n_segments=n_segments,
            n_features=n_features,
        )

        # Start from scratch so that refitting does not accumulate features of a previous fit.
        self.selected_features = []
        if n_features < self.top_k:
            # Not enough candidates to filter: keep all of them, as announced by the warning above.
            self.selected_features.extend(list(feature_segments_ranking))
            return self

        # Earlier steps select one feature per segment, so the last step takes the remainder up to top_k.
        # (``top_k % n_segments`` would be 0 whenever top_k is a multiple of n_segments.)
        last_step_features_number = max(self.top_k - n_segments * (gale_shapley_steps_number - 1), 0)
        for step in range(gale_shapley_steps_number):
            matches = self._gale_shapley_iteration(
                segment_features_ranking=segment_features_ranking,
                feature_segments_ranking=feature_segments_ranking,
            )
            if step == gale_shapley_steps_number - 1:
                selected_features = self._process_last_step(
                    matches=matches,
                    relevance_table=relevance_table,
                    n=last_step_features_number,
                    greater_is_better=self.greater_is_better,
                )
            else:
                selected_features = list(matches.values())
            self.selected_features.extend(selected_features)
            # selected features must not be proposed again in the following steps
            segment_features_ranking = self._update_ranking_list(
                segment_features_ranking=segment_features_ranking, features_to_drop=selected_features
            )
        return self