Source code for etna.transforms.timestamp.holiday
import datetime
from typing import Optional
import holidays
import numpy as np
import pandas as pd
from etna.transforms.base import FutureMixin
from etna.transforms.base import Transform
class HolidayTransform(Transform, FutureMixin):
    """HolidayTransform generates series that indicates holidays in given dataframe.

    For every timestamp of the input index a 0/1 flag is computed (1 when the
    timestamp is contained in the country's holiday calendar) and the same
    flag column is added to every segment of the dataframe.
    """

    def __init__(self, iso_code: str = "RUS", out_column: Optional[str] = None):
        """
        Create instance of HolidayTransform.

        Parameters
        ----------
        iso_code:
            internationally recognised codes, designated to country for which we want to find the holidays
        out_column:
            name of added column. Use ``self.__repr__()`` if not given.
        """
        self.iso_code = iso_code
        # NOTE(review): recent releases of ``holidays`` appear to deprecate
        # ``CountryHoliday`` in favour of ``country_holidays`` -- confirm against
        # the version pinned by the project before switching.
        self.holidays = holidays.CountryHoliday(iso_code)
        self.out_column = out_column
        if self.out_column is None:
            # No explicit name supplied: fall back to the transform's repr.
            self.out_column = self.__repr__()

    def fit(self, df: pd.DataFrame) -> "HolidayTransform":
        """
        Fit HolidayTransform with data from df. Does nothing in this case.

        Parameters
        ----------
        df: pd.DataFrame
            value series with index column in timestamp format

        Returns
        -------
        :
            the transform itself, unchanged
        """
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Transform data from df with HolidayTransform and generate a column of holidays flags.

        Parameters
        ----------
        df: pd.DataFrame
            value series with index column in timestamp format; its columns
            must carry a ``"segment"`` level

        Returns
        -------
        :
            pd.DataFrame with added holidays: one categorical 0/1 column named
            ``out_column`` per segment, with columns sorted

        Raises
        ------
        ValueError:
            if the gap between the first two timestamps is larger than one day
        """
        # The frequency can only be inferred from two consecutive timestamps.
        # With an empty or single-row frame there is nothing to compare, so the
        # check is skipped instead of failing with an IndexError.
        if len(df.index) > 1 and (df.index[1] - df.index[0]) > datetime.timedelta(days=1):
            raise ValueError("Frequency of data should be no more than daily.")

        segments = df.columns.get_level_values("segment").unique()

        # One flag per timestamp; dtype is pinned so that an empty index still
        # yields an integer array rather than numpy's default float.
        flags = np.array([int(timestamp in self.holidays) for timestamp in df.index], dtype=int)
        # Holidays do not depend on the segment: replicate the flag column for each one.
        encoded_matrix = flags.reshape(-1, 1).repeat(len(segments), axis=1)

        encoded_df = pd.DataFrame(
            encoded_matrix,
            columns=pd.MultiIndex.from_product([segments, [self.out_column]], names=("segment", "feature")),
            index=df.index,
        )
        encoded_df = encoded_df.astype("category")

        df = df.join(encoded_df)
        df = df.sort_index(axis=1)
        return df