diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
index 32df3e69d..8c80e3bb4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
@@ -509,7 +509,7 @@ public abstract class DifferentialFunction {
* @return the arguments for a given function
*/
public SDVariable[] args() {
- return sameDiff.getInputVariablesForOp(this);
+ return sameDiff == null ? null : sameDiff.getInputVariablesForOp(this);
}
/**
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener.java
new file mode 100644
index 000000000..103b0f960
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener.java
@@ -0,0 +1,189 @@
+package org.nd4j.autodiff.listeners.debugging;
+
+import lombok.*;
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.util.ArrayUtil;
+
+import java.text.DecimalFormat;
+import java.util.*;
+
+/**
+ * A simple listener for benchmarking single operations in SameDiff
+ * Supports 2 modes:
+ * - SINGLE_ITER_PRINT: Print the runtime of the first iteration
+ * - AGGREGATE: Collect statistics for multiple runs, that can be accessed (by op name) via {@link #getAggregateModeMap()}
+ *
+ * @author Alex Black
+ */
+@Getter
+public class OpBenchmarkListener extends BaseListener {
+
+ public enum Mode {SINGLE_ITER_PRINT, AGGREGATE}
+
+ private final Operation operation;
+ private final Mode mode;
+ private final long minRuntime;
+ private Map aggregateModeMap;
+
+ @Getter(AccessLevel.PRIVATE)
+ private long start;
+ @Getter(AccessLevel.PRIVATE)
+ private boolean printActive;
+ private boolean printDone;
+
+ public OpBenchmarkListener(Operation operation, @NonNull Mode mode) {
+ this(operation, mode, 0);
+ }
+
+ /**
+ * @param operation Operation to collect stats for
+ * @param mode Mode - see {@link OpBenchmarkListener}
+ * @param minRuntime Minimum runtime - only applies to Mode.SINGLE_ITER_PRINT. If op runtime below this: don't print
+ */
+ public OpBenchmarkListener(Operation operation, @NonNull Mode mode, long minRuntime) {
+ this.operation = operation;
+ this.mode = mode;
+ this.minRuntime = minRuntime;
+ }
+
+ @Override
+ public boolean isActive(Operation operation) {
+ return this.operation == null || this.operation == operation;
+ }
+
+ @Override
+ public void operationStart(SameDiff sd, Operation op) {
+ if(printDone)
+ return;
+ if(this.operation == null || this.operation == op)
+ printActive = true;
+ }
+
+ @Override
+ public void operationEnd(SameDiff sd, Operation op) {
+ if(printDone)
+ return;
+ if(this.operation == null || this.operation == op) {
+ printActive = false;
+ printDone = true;
+ }
+ }
+
+ @Override
+ public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+ start = System.currentTimeMillis();
+ }
+
+ @Override
+ public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+ long now = System.currentTimeMillis();
+
+ if (mode == Mode.SINGLE_ITER_PRINT && printActive && (now-start) > this.minRuntime) {
+ System.out.println(getOpString(op, now));
+ } else if (mode == Mode.AGGREGATE) {
+ if(aggregateModeMap == null)
+ aggregateModeMap = new LinkedHashMap<>();
+
+ if(!aggregateModeMap.containsKey(op.getName())){
+ String s = getOpString(op, null);
+ OpExec oe = new OpExec(op.getName(), op.getOp().opName(), op.getOp().getClass(),
+ new ArrayList(), s);
+ aggregateModeMap.put(op.getName(), oe);
+ }
+
+ aggregateModeMap.get(op.getName()).getRuntimeMs().add(now-start);
+ }
+ }
+
+ private String getOpString(SameDiffOp op, Long now){
+ StringBuilder sb = new StringBuilder();
+ sb.append(op.getName()).append(" - ").append(op.getOp().getClass().getSimpleName())
+ .append("(").append(op.getOp().opName()).append(") - ");
+ if(now != null) {
+ sb.append(now - start).append(" ms\n");
+ }
+
+ if (op.getOp() instanceof DynamicCustomOp) {
+ DynamicCustomOp dco = (DynamicCustomOp) op.getOp();
+ int x = 0;
+
+ for (INDArray i : dco.inputArguments()) {
+ sb.append(" in ").append(x++).append(": ").append(i.shapeInfoToString()).append("\n");
+ }
+ x = 0;
+ for (INDArray o : dco.outputArguments()) {
+ sb.append(" out ").append(x++).append(": ").append(o.shapeInfoToString()).append("\n");
+ }
+ long[] iargs = dco.iArgs();
+ boolean[] bargs = dco.bArgs();
+ double[] targs = dco.tArgs();
+ if (iargs != null && iargs.length > 0) {
+ sb.append(" iargs: ").append(Arrays.toString(iargs)).append("\n");
+ }
+ if (bargs != null && bargs.length > 0) {
+ sb.append(" bargs: ").append(Arrays.toString(bargs)).append("\n");
+ }
+ if (targs != null && targs.length > 0) {
+ sb.append(" targs: ").append(Arrays.toString(targs)).append("\n");
+ }
+ } else {
+ Op o = (Op) op.getOp();
+ if (o.x() != null)
+ sb.append(" x: ").append(o.x().shapeInfoToString());
+ if (o.y() != null)
+ sb.append(" y: ").append(o.y().shapeInfoToString());
+ if (o.z() != null)
+ sb.append(" z: ").append(o.z().shapeInfoToString());
+ }
+ return sb.toString();
+ }
+
+
+ @AllArgsConstructor
+ @Data
+ public static class OpExec {
+ private final String opOwnName;
+ private final String opName;
+ private final Class> opClass;
+ private List runtimeMs;
+ private String firstIter;
+
+ @Override
+ public String toString(){
+ DecimalFormat df = new DecimalFormat("0.000");
+
+ return opOwnName + " - op class: " + opClass.getSimpleName() + " (op name: " + opName + ")\n"
+ + "count: " + runtimeMs.size() + ", mean: " + df.format(avgMs()) + "ms, std: " + df.format(stdMs()) + "ms, min: " + minMs() + "ms, max: " + maxMs() + "ms\n"
+ + firstIter;
+ }
+
+ public double avgMs() {
+ long sum = 0;
+ for (Long l : runtimeMs) {
+ sum += l;
+ }
+ return sum / (double) runtimeMs.size();
+ }
+
+ public double stdMs() {
+ return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).stdNumber().doubleValue();
+ }
+
+ public long minMs() {
+ return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).minNumber().longValue();
+ }
+
+ public long maxMs() {
+ return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).maxNumber().longValue();
+ }
+ }
+}