当前位置:  开发笔记 > 编程语言 > 正文

LSTM - 对部分序列进行预测

如何解决《LSTM-对部分序列进行预测》经验,为你挑选了0个好方法。

这个问题继续我之前提出的一个问题.

我已经训练了一个LSTM模型来预测100个样本的批次的二进制类(1或0),每个样本有3个特征,即:数据的形状是(m,100,3),其中m是批次的数量.

数据:

[
    [[1,2,3],[1,2,3]... 100 sampels],
    [[1,2,3],[1,2,3]... 100 sampels],
    ... avaialble batches in the training data
]

目标:

[
   [1]
   [0]
   ...
]

型号代码:

def build_model(num_samples, num_features, is_training):
    model = Sequential()
    opt = optimizers.Adam(lr=0.0005, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0001)

    batch_size = None if is_training else 1
    stateful = False if is_training else True
    first_lstm = LSTM(32, batch_input_shape=(batch_size, num_samples, num_features), return_sequences=True,
                      activation='tanh', stateful=stateful)

    model.add(first_lstm)
    model.add(LeakyReLU())
    model.add(Dropout(0.2))
    model.add(LSTM(16, return_sequences=True, activation='tanh', stateful=stateful))
    model.add(Dropout(0.2))
    model.add(LeakyReLU())
    model.add(LSTM(8, return_sequences=False, activation='tanh', stateful=stateful))
    model.add(LeakyReLU())
    model.add(Dense(1, activation='sigmoid'))

    if is_training:
        model.compile(loss='binary_crossentropy', optimizer=opt,
                      metrics=['accuracy', keras_metrics.precision(), keras_metrics.recall(), f1])
    return model

对于训练阶段,模型不是有状态的.在预测我正在使用有状态模型时,迭代数据并输出每个样本的概率:

for index, row in data.iterrows():
    if index % 100 == 0:
        predicting_model.reset_states()
    vals = np.array([[row[['a', 'b', 'c']].values]])
    prob = predicting_model.predict_on_batch(vals)

当查看批处理结束时的概率时,它正是我用整个批处理预测时得到的值(不是一个接一个).但是,我预计当新样本到达时,概率将始终在正确的方向上继续.实际发生的是,概率输出可能会在任意样本上出现错误的类别(见下文).


预测时100个样品批次的两个样品(标签= 1):

在此输入图像描述

和Label = 0: 在此输入图像描述

有没有办法实现我想要的(避免极端尖峰,同时预测概率),或者这是一个给定的事实?

任何解释,建议将不胜感激.


更新 感谢@today建议,我尝试使用return_sequence=True最后一个LSTM层在每个输入时间步骤中使用隐藏状态输出训练网络.

所以现在标签看起来像这样(形状(100,100)):

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
...]

模型摘要:

Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 100, 32)           4608      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 100, 32)           0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 100, 32)           0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 100, 16)           3136      
_________________________________________________________________
dropout_2 (Dropout)          (None, 100, 16)           0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 100, 16)           0         
_________________________________________________________________
lstm_3 (LSTM)                (None, 100, 8)            800       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 100, 8)            0         
_________________________________________________________________
dense_1 (Dense)              (None, 100, 1)            9         
=================================================================
Total params: 8,553
Trainable params: 8,553
Non-trainable params: 0
_________________________________________________________________

但是,我得到一个例外:

ValueError: Error when checking target: expected dense_1 to have 3 dimensions, but got array with shape (75, 100)

我需要修理什么?

推荐阅读
mobiledu2402851173
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有