Dish-TS: A Module for Alleviating Distribution Shift in Time Series Forecasting (Usage Example)

Preface

  • Earlier I wrote a walkthrough of the paper Dish-TS, a general paradigm for alleviating distribution shift in time series forecasting (part of my paper-interpretation series; see the paper and the GitHub project for the links).
  • Since this module can be combined with any deep time series forecasting model, and achieves better results than RevIN, here I demonstrate how to use it through an example.

Dish-TS Module

import torch
import torch.nn as nn
import torch.nn.functional as F


class DishTS(nn.Module):
    def __init__(self, args):
        super().__init__()
        init = args.dish_init #'standard', 'avg' or 'uniform'
        activate = True
        n_series = args.n_series # number of series
        lookback = args.seq_len # lookback length
        if init == 'standard':
            self.reduce_mlayer = nn.Parameter(torch.rand(n_series, lookback, 2)/lookback)
        elif init == 'avg':
            self.reduce_mlayer = nn.Parameter(torch.ones(n_series, lookback, 2)/lookback)
        elif init == 'uniform':
            self.reduce_mlayer = nn.Parameter(torch.ones(n_series, lookback, 2)/lookback+torch.rand(n_series, lookback, 2)/lookback)
        self.gamma, self.beta = nn.Parameter(torch.ones(n_series)), nn.Parameter(torch.zeros(n_series))
        self.activate = activate

    def forward(self, batch_x, mode='forward', dec_inp=None):
        if mode == 'forward':
            # batch_x: B*L*D || dec_inp: B*?*D (for xxformers)
            self.preget(batch_x)
            batch_x = self.forward_process(batch_x)
            dec_inp = None if dec_inp is None else self.forward_process(dec_inp)
            return batch_x, dec_inp
        elif mode == 'inverse':
            # batch_x: B*H*D (forecasts)
            batch_y = self.inverse_process(batch_x)
            return batch_y

    def preget(self, batch_x):
        # Project the lookback window onto two learned "levels":
        # phil for the input (lookback) space, phih for the output (horizon) space.
        x_transpose = batch_x.permute(2,0,1)                               # D*B*L
        theta = torch.bmm(x_transpose, self.reduce_mlayer).permute(1,2,0)  # B*2*D
        if self.activate:
            theta = F.gelu(theta)
        self.phil, self.phih = theta[:,:1,:], theta[:,1:,:]                # B*1*D each
        # unbiased scale estimates around each level
        self.xil = torch.sum(torch.pow(batch_x - self.phil,2), axis=1, keepdim=True) / (batch_x.shape[1]-1)
        self.xih = torch.sum(torch.pow(batch_x - self.phih,2), axis=1, keepdim=True) / (batch_x.shape[1]-1)

    def forward_process(self, batch_input):
        temp = (batch_input - self.phil)/torch.sqrt(self.xil + 1e-8)
        rst = temp.mul(self.gamma) + self.beta
        return rst

    def inverse_process(self, batch_input):
        return ((batch_input - self.beta) / self.gamma) * torch.sqrt(self.xih + 1e-8) + self.phih
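
To sanity-check the shapes, here is a minimal sketch of my own (not from the repo) that pushes random data through the module above; I assume an argparse-style namespace carrying the three fields the constructor reads (dish_init, n_series, seq_len):

from types import SimpleNamespace

import torch

# hypothetical config: 7 series, lookback of 96
args = SimpleNamespace(dish_init='standard', n_series=7, seq_len=96)
dish = DishTS(args)

batch_x = torch.randn(32, 96, 7)        # B*L*D lookback windows
norm_x, _ = dish(batch_x, 'forward')    # normalize before the backbone
print(norm_x.shape)                     # torch.Size([32, 96, 7])

forecast = torch.randn(32, 24, 7)       # B*H*D dummy backbone output
pred = dish(forecast, 'inverse')        # de-normalize with phih / xih
print(pred.shape)                       # torch.Size([32, 24, 7])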

RevIN Module

import torch
import torch.nn as nn


class RevIN(nn.Module):
    def __init__(self, args):
        super().__init__()
        if args.affine: # args.affine: use affine layers or not
            self.gamma = nn.Parameter(torch.ones(args.n_series)) # args.n_series: number of series
            self.beta = nn.Parameter(torch.zeros(args.n_series))
        else:
            self.gamma, self.beta = 1, 0

    def forward(self, batch_x, mode='forward', dec_inp=None):
        if mode == 'forward':
            # batch_x: B*L*D || dec_inp: B*?*D (for xxformers)
            self.preget(batch_x)
            batch_x = self.forward_process(batch_x)
            dec_inp = None if dec_inp is None else self.forward_process(dec_inp)
            return batch_x, dec_inp
        elif mode == 'inverse':
            # batch_x: B*H*D (forecasts)
            batch_y = self.inverse_process(batch_x)
            return batch_y

    def preget(self, batch_x):
        self.avg = torch.mean(batch_x, axis=1, keepdim=True).detach() # b*1*d
        self.var = torch.var(batch_x, axis=1, keepdim=True).detach()  # b*1*d

    def forward_process(self, batch_input):
        temp = (batch_input - self.avg)/torch.sqrt(self.var + 1e-8)
        return temp.mul(self.gamma) + self.beta

    def inverse_process(self, batch_input):
        return ((batch_input - self.beta) / self.gamma) * torch.sqrt(self.var + 1e-8) + self.avg
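
Note the key difference: RevIN de-normalizes with the same per-window mean and variance it used for normalization, while Dish-TS learns two separate sets of statistics (phil/xil for the input space, phih/xih for the output space), which is exactly how it models the shift between lookback and horizon.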

Usage Example

Suppose I have a very simple NLinear model, whose input x has shape [Batch, Input length, Channel]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class NLinear(nn.Module):
    def __init__(self, configs):
        super(NLinear, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.channels = configs.enc_in
        self.individual = configs.individual
        if self.individual:
            self.Linear = nn.ModuleList()
            for i in range(self.channels):
                self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
        else:
            self.Linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        seq_last = x[:,-1:,:].detach()
        x = x - seq_last
        if self.individual:
            output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
            for i in range(self.channels):
                output[:,:,i] = self.Linear[i](x[:,:,i])
            x = output
        else:
            x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
        x = x + seq_last
        return x # [Batch, Output length, Channel]
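
Note that NLinear already carries a simple normalization of its own (it subtracts the last value seq_last of the window and adds it back at the end); Dish-TS is applied outside the backbone, before the input and after the output, so the backbone code itself stays untouched.
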
  • All I need to do is run x through the Dish-TS module before it enters the model, and apply the inverse normalization to the model's output afterwards:
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, args, forecast_model, norm_model):
        super().__init__()
        self.args = args
        self.fm = forecast_model
        self.nm = norm_model

    def forward(self, batch_x, dec_inp):
        if self.nm is not None:
            batch_x, dec_inp = self.nm(batch_x, 'forward', dec_inp)

        if 'former' in self.args.model:
            forecast = self.fm(batch_x, None, dec_inp, None)
        else:
            forecast = self.fm(batch_x)

        if self.nm is not None:
            forecast = self.nm(forecast, 'inverse')

        return forecast
  • When using it, I only need to write something like:
    parser.add_argument('--norm', type=str, default='none') # none, revin, dishts
    parser.add_argument('--affine', type=int, default=1) # for RevIN: use affine parameters or not
    parser.add_argument('--dish_init', type=str, default='standard') # 'standard', 'avg' or 'uniform'
    # args is the parsed argument namespace (it also needs model, n_series, seq_len, pred_len, enc_in, individual)
    norm_map = {'revin': RevIN, 'dishts': DishTS}
    nm = norm_map[args.norm](args) if args.norm in norm_map else None
    model = Model(args, NLinear(args), nm) # pass module instances, not strings
  • As for the dish_init parameter, the author's experiments in the paper show that none of the three initialization schemes ('standard', 'avg', 'uniform') makes a significant difference to Dish-TS's performance. Of course, this may also depend on the data, so I recommend trying all three and keeping whichever works best; see the sketch below.
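
Putting everything together, a minimal end-to-end sketch (dummy data, hypothetical hyper-parameters) that tries the three initializations in turn could look like this:

from types import SimpleNamespace

import torch
import torch.nn.functional as F

# hypothetical hyper-parameters; enc_in equals n_series here
base = dict(model='NLinear', n_series=7, enc_in=7, individual=False,
            seq_len=96, pred_len=24)

batch_x = torch.randn(32, 96, 7)   # dummy lookback windows
batch_y = torch.randn(32, 24, 7)   # dummy targets

for init in ['standard', 'avg', 'uniform']:
    args = SimpleNamespace(**base, dish_init=init)
    model = Model(args, NLinear(args), DishTS(args))
    forecast = model(batch_x, dec_inp=None)
    loss = F.mse_loss(forecast, batch_y)
    print(init, loss.item())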