diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java index 14c4b3d73..3a30ec9ef 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java @@ -63,10 +63,10 @@ public class KerasReLU extends KerasLayer { double negativeSlope = 0.0; double threshold = 0.0; if (innerConfig.containsKey("negative_slope")) { - negativeSlope = (double) innerConfig.get("negative_slope"); + negativeSlope = ((Number)innerConfig.get("negative_slope")).doubleValue(); } if (innerConfig.containsKey("threshold")) { - threshold = (double) innerConfig.get("threshold"); + threshold = ((Number)innerConfig.get("threshold")).doubleValue(); } this.layer = new ActivationLayer.Builder().name(this.layerName) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java index 7f7d8dc4c..c27f753d1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java @@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -351,6 +352,10 @@ public class KerasBatchNormalization extends KerasLayer { private int getBatchNormAxis(Map layerConfig) throws InvalidKerasConfigurationException { Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); - return (int) innerConfig.get(LAYER_FIELD_AXIS); + Object batchNormAxis = innerConfig.get(LAYER_FIELD_AXIS); + if (batchNormAxis instanceof List){ + return ((Number)((List)batchNormAxis).get(0)).intValue(); + } + return ((Number)innerConfig.get(LAYER_FIELD_AXIS)).intValue(); } }