From 5d364f9af0d42e86727bfb7bc365b7dda21ac2e9 Mon Sep 17 00:00:00 2001 From: ahuizxc Date: Sun, 29 Sep 2019 16:49:40 +0800 Subject: [PATCH] fix crf layer mask tensor type ERROR when concatenate --- keras_contrib/layers/crf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_contrib/layers/crf.py b/keras_contrib/layers/crf.py index 88a64ac69..3d43537d4 100644 --- a/keras_contrib/layers/crf.py +++ b/keras_contrib/layers/crf.py @@ -513,6 +513,7 @@ def recursion(self, input_energy, mask=None, go_backwards=False, constants = [chain_energy] if mask is not None: + mask = K.cast(mask, K.floatx()) mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1), K.floatx()) constants.append(mask2)