Skip to content

Commit 9ebc972

Browse files
dwkk-googlepartheakweinmeister
authored
docs(samples): replace deprecated fields in create_training_pipeline_tabular_forecasting_sample.py (#981)
* Update create_training_pipeline_tabular_forecasting_sample.py * Update create_training_pipeline_tabular_forecasting_sample_test.py Co-authored-by: Anthonios Partheniou <partheniou@google.com> Co-authored-by: Karl Weinmeister <11586922+kweinmeister@users.noreply.github.com>
1 parent ea16849 commit 9ebc972

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

samples/snippets/pipeline_service/create_training_pipeline_tabular_forecasting_sample.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def create_training_pipeline_tabular_forecasting_sample(
2626
target_column: str,
2727
time_series_identifier_column: str,
2828
time_column: str,
29-
static_columns: str,
30-
time_variant_past_only_columns: str,
31-
time_variant_past_and_future_columns: str,
32-
forecast_window_end: int,
29+
time_series_attribute_columns: str,
30+
unavailable_at_forecast: str,
31+
available_at_forecast: str,
32+
forecast_horizon: int,
3333
location: str = "us-central1",
3434
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
3535
):
@@ -47,7 +47,7 @@ def create_training_pipeline_tabular_forecasting_sample(
4747
{"auto": {"column_name": "deaths"}},
4848
]
4949

50-
period = {"unit": "day", "quantity": 1}
50+
data_granularity = {"unit": "day", "quantity": 1}
5151

5252
# the inputs should be formatted according to the training_task_definition yaml file
5353
training_task_inputs_dict = {
@@ -56,13 +56,13 @@ def create_training_pipeline_tabular_forecasting_sample(
5656
"timeSeriesIdentifierColumn": time_series_identifier_column,
5757
"timeColumn": time_column,
5858
"transformations": transformations,
59-
"period": period,
59+
"dataGranularity": data_granularity,
6060
"optimizationObjective": "minimize-rmse",
6161
"trainBudgetMilliNodeHours": 8000,
62-
"staticColumns": static_columns,
63-
"timeVariantPastOnlyColumns": time_variant_past_only_columns,
64-
"timeVariantPastAndFutureColumns": time_variant_past_and_future_columns,
65-
"forecastWindowEnd": forecast_window_end,
62+
"timeSeriesAttributeColumns": time_series_attribute_columns,
63+
"unavailableAtForecast": unavailable_at_forecast,
64+
"availableAtForecast": available_at_forecast,
65+
"forecastHorizon": forecast_horizon,
6666
}
6767

6868
training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())

samples/snippets/pipeline_service/create_training_pipeline_tabular_forecasting_sample_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ def test_ucaip_generated_create_training_pipeline_sample(capsys, shared_state):
7777
target_column=TARGET_COLUMN,
7878
time_series_identifier_column="county",
7979
time_column="date",
80-
static_columns=["state_name"],
81-
time_variant_past_only_columns=["deaths"],
82-
time_variant_past_and_future_columns=["date"],
83-
forecast_window_end=10,
80+
time_series_attribute_columns=["state_name"],
81+
unavailable_at_forecast=["deaths"],
82+
available_at_forecast=["date"],
83+
forecast_horizon=10,
8484
)
8585

8686
out, _ = capsys.readouterr()

0 commit comments

Comments
 (0)