package weka.classifiers.pmml.consumer;

import at.tugraz.ist.spreadsheet.gui.panel.spreadsheet.worksheet.cell.ContentCellPanel;
import java.io.Serializable;
import java.util.ArrayList;
import org.apache.commons.lang3.StringUtils;
import org.apache.xmlbeans.impl.jam.xml.JamXmlElements;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.matrix.Maths;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;

/**
 * Scores instances with a PMML {@code RegressionModel}.
 *
 * <p>The model is made of one or more {@link RegressionTable}s. For a numeric target
 * ({@code functionName="regression"}) a table writes its score into slot 0 of the output
 * array; for a nominal target ({@code functionName="classification"}) each table writes the
 * score of its own target category. The raw scores are then post-processed according to the
 * model's {@code normalizationMethod} attribute.
 *
 * <p>Recovered from {@code weka/classifiers/pmml/consumer/Regression.class}.
 */
public class Regression extends PMMLClassifier implements Serializable {
    private static final long serialVersionUID = -5551125528409488634L;

    /** Value of the PMML {@code algorithmName} attribute, or {@code null} if absent or empty. */
    protected String m_algorithmName;

    /** One entry per {@code RegressionTable} element found under the model element. */
    protected RegressionTable[] m_regressionTables;

    /** Post-processing applied to the raw table outputs; {@link Normalization#NONE} if unspecified. */
    protected Normalization m_normalizationMethod;

    /**
     * Values of the PMML {@code normalizationMethod} attribute understood by this consumer.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private.
     */
    public enum Normalization {
        NONE,
        SIMPLEMAX,
        SOFTMAX,
        LOGIT,
        PROBIT,
        CLOGLOG,
        EXP,
        LOGLOG,
        CAUCHIT
    }

    /**
     * One PMML {@code RegressionTable}: an intercept plus the contributions of its
     * {@code NumericPredictor}, {@code CategoricalPredictor} and {@code PredictorTerm} children,
     * accumulated into a single slot of the output array.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private.
     */
    public static class RegressionTable implements Serializable {
        private static final long serialVersionUID = -5259866093996338995L;

        /** Function type for a numeric target; the table writes into output slot 0. */
        public static final int REGRESSION = 0;

        /** Function type for a nominal target; the table writes into its target category's slot. */
        public static final int CLASSIFICATION = 1;

        protected int m_functionType;
        protected MiningSchema m_miningSchema;
        protected double m_intercept;

        /** Index of the class value scored by this table, or -1 for a regression table. */
        protected int m_targetCategory;

        protected ArrayList<Predictor> m_predictors = new ArrayList<>();
        protected ArrayList<PredictorTerm> m_predictorTerms = new ArrayList<>();

        /** Returns the output slot this table and its predictors accumulate into. */
        private int outputSlot() {
            return this.m_targetCategory == -1 ? 0 : this.m_targetCategory;
        }

        /**
         * A {@code CategoricalPredictor}: adds its coefficient when the mining-schema attribute
         * takes the configured value.
         */
        protected class CategoricalPredictor extends Predictor {
            private static final long serialVersionUID = 3077920125549906819L;
            protected String m_valueName;
            protected int m_valueIndex;

            /**
             * @param element the {@code CategoricalPredictor} element
             * @param instances the mining schema fields as an {@link Instances} header
             * @throws Exception if the {@code value} attribute is missing or not a value of the field
             */
            protected CategoricalPredictor(Element element, Instances instances) throws Exception {
                super(element, instances);
                this.m_valueIndex = -1;
                String value = element.getAttribute("value");
                if (value.length() == 0) {
                    throw new Exception("[CategoricalPredictor] attribute value not specified!");
                }
                this.m_valueName = value;
                Attribute att = instances.attribute(this.m_miningSchemaAttIndex);
                if (att.isString()) {
                    // String attributes have an open value set: register the value before looking it up.
                    att.addStringValue(this.m_valueName);
                }
                this.m_valueIndex = att.indexOfValue(this.m_valueName);
                if (this.m_valueIndex == -1) {
                    throw new Exception("[CategoricalPredictor] unable to find value " + this.m_valueName + " in mining schema attribute " + att.name());
                }
            }

            @Override
            public String toString() {
                // NOTE(review): FORMULA_DELIMITER is an unrelated GUI constant that the decompiler
                // substituted for an equal-valued string literal (presumably "="); confirm and inline.
                return super.toString() + this.m_name + ContentCellPanel.FORMULA_DELIMITER + this.m_valueName;
            }

            @Override
            public void add(double[] preds, double[] input) {
                if (this.m_valueIndex == (int) input[this.m_miningSchemaAttIndex]) {
                    preds[RegressionTable.this.outputSlot()] += this.m_coefficient;
                }
            }
        }

        /** A {@code NumericPredictor}: adds {@code coefficient * x^exponent}. */
        protected class NumericPredictor extends Predictor {
            private static final long serialVersionUID = -4335075205696648273L;

            /** Exponent applied to the field value; defaults to 1. */
            protected double m_exponent;

            protected NumericPredictor(Element element, Instances instances) throws Exception {
                super(element, instances);
                this.m_exponent = 1.0d;
                String exponent = element.getAttribute("exponent");
                if (exponent.length() > 0) {
                    this.m_exponent = Double.parseDouble(exponent);
                }
            }

            @Override
            public String toString() {
                String text = super.toString() + this.m_name;
                if (this.m_exponent > 1.0d || this.m_exponent < 1.0d) {
                    text = text + "^" + Utils.doubleToString(this.m_exponent, 4);
                }
                return text;
            }

            @Override
            public void add(double[] preds, double[] input) {
                preds[RegressionTable.this.outputSlot()] += this.m_coefficient * Math.pow(input[this.m_miningSchemaAttIndex], this.m_exponent);
            }
        }

        /**
         * Base class for single-field predictors: resolves the {@code name} attribute to a
         * mining-schema attribute index and parses the optional {@code coefficient} (default 1).
         *
         * <p>NOTE(review): the decompiler reports this was originally package-private.
         */
        public static abstract class Predictor implements Serializable {
            private static final long serialVersionUID = 7043831847273383618L;
            protected String m_name;
            protected int m_miningSchemaAttIndex;
            protected double m_coefficient;

            /**
             * @throws Exception if no mining-schema attribute has the predictor's name
             */
            protected Predictor(Element element, Instances instances) throws Exception {
                this.m_miningSchemaAttIndex = -1;
                this.m_coefficient = 1.0d;
                this.m_name = element.getAttribute("name");
                for (int i = 0; i < instances.numAttributes(); i++) {
                    if (instances.attribute(i).name().equals(this.m_name)) {
                        this.m_miningSchemaAttIndex = i;
                    }
                }
                if (this.m_miningSchemaAttIndex == -1) {
                    throw new Exception("[Predictor] unable to find matching attribute for predictor " + this.m_name);
                }
                String coefficient = element.getAttribute("coefficient");
                if (coefficient.length() > 0) {
                    this.m_coefficient = Double.parseDouble(coefficient);
                }
            }

            @Override
            public String toString() {
                return Utils.doubleToString(this.m_coefficient, 12, 4) + " * ";
            }

            /**
             * Adds this predictor's contribution for the schema-mapped {@code input} into {@code preds}.
             */
            public abstract void add(double[] preds, double[] input);
        }

        /**
         * A {@code PredictorTerm}: adds {@code coefficient * x1 * x2 * ...} over the numeric
         * fields referenced by its {@code FieldRef} children.
         */
        public class PredictorTerm implements Serializable {
            private static final long serialVersionUID = 5493100145890252757L;
            protected double m_coefficient;

            /** Mining-schema indexes of the referenced fields; empty if there are no FieldRefs. */
            protected int[] m_indexes;

            /** Names of the referenced fields, parallel to {@link #m_indexes}. */
            protected String[] m_fieldNames;

            /**
             * @throws Exception if the coefficient cannot be parsed, or a {@code FieldRef} has no
             *     {@code field} attribute, names an unknown field, or names a non-numeric field
             */
            protected PredictorTerm(Element element, Instances instances) throws Exception {
                this.m_coefficient = 1.0d;
                String coefficient = element.getAttribute("coefficient");
                if (coefficient != null && coefficient.length() > 0) {
                    try {
                        this.m_coefficient = Double.parseDouble(coefficient);
                    } catch (IllegalArgumentException e) {
                        throw new Exception("[PredictorTerm] unable to parse coefficient", e);
                    }
                }
                NodeList fieldRefs = element.getElementsByTagName("FieldRef");
                int numRefs = fieldRefs.getLength();
                // Always allocate (possibly empty) so toString()/add() never see null arrays.
                this.m_indexes = new int[numRefs];
                this.m_fieldNames = new String[numRefs];
                for (int i = 0; i < numRefs; i++) {
                    // getElementsByTagName only returns Element nodes.
                    String fieldName = ((Element) fieldRefs.item(i)).getAttribute("field");
                    if (fieldName == null || fieldName.length() == 0) {
                        throw new Exception("[PredictorTerm] FieldRef has no field attribute!");
                    }
                    int index = -1;
                    for (int j = 0; j < instances.numAttributes(); j++) {
                        if (instances.attribute(j).name().equals(fieldName)) {
                            index = j;
                            break;
                        }
                    }
                    if (index == -1) {
                        throw new Exception("[PredictorTerm] Unable to find field " + fieldName + " in mining schema!");
                    }
                    if (!instances.attribute(index).isNumeric()) {
                        throw new Exception("[PredictorTerm] field is not continuous: " + fieldName);
                    }
                    this.m_indexes[i] = index;
                    this.m_fieldNames[i] = fieldName;
                }
            }

            @Override
            public String toString() {
                StringBuilder text = new StringBuilder();
                text.append("(").append(Utils.doubleToString(this.m_coefficient, 12, 4));
                for (String fieldName : this.m_fieldNames) {
                    text.append(" * ").append(fieldName);
                }
                text.append(")");
                return text.toString();
            }

            /** Adds {@code coefficient * product(input[index])} into this table's output slot. */
            public void add(double[] preds, double[] input) {
                double product = this.m_coefficient;
                for (int index : this.m_indexes) {
                    product *= input[index];
                }
                preds[RegressionTable.this.outputSlot()] += product;
            }
        }

        @Override
        public String toString() {
            Instances fields = this.m_miningSchema.getFieldsAsInstances();
            StringBuilder text = new StringBuilder();
            text.append("Regression table:\n");
            text.append(fields.classAttribute().name());
            if (this.m_functionType == CLASSIFICATION) {
                text.append(ContentCellPanel.FORMULA_DELIMITER + fields.classAttribute().value(this.m_targetCategory));
            }
            text.append(" =\n\n");
            for (Predictor predictor : this.m_predictors) {
                text.append(predictor.toString()).append(" +\n");
            }
            for (PredictorTerm term : this.m_predictorTerms) {
                text.append(term.toString()).append(" +\n");
            }
            text.append(Utils.doubleToString(this.m_intercept, 12, 4));
            text.append("\n\n");
            return text.toString();
        }

        /**
         * Parses a {@code RegressionTable} element.
         *
         * @param element the {@code RegressionTable} element
         * @param functionType {@link #REGRESSION} or {@link #CLASSIFICATION}
         * @param miningSchema the model's mining schema
         * @throws Exception if a classification table has no resolvable {@code targetCategory},
         *     or any predictor fails to parse
         */
        protected RegressionTable(Element element, int functionType, MiningSchema miningSchema) throws Exception {
            this.m_intercept = 0.0d;
            this.m_targetCategory = -1;
            this.m_miningSchema = miningSchema;
            this.m_functionType = functionType;
            Instances fields = this.m_miningSchema.getFieldsAsInstances();
            String intercept = element.getAttribute("intercept");
            if (intercept.length() > 0) {
                this.m_intercept = Double.parseDouble(intercept);
            }
            if (this.m_functionType == CLASSIFICATION) {
                String targetCategory = element.getAttribute("targetCategory");
                if (targetCategory.length() > 0) {
                    Attribute classAttribute = fields.classAttribute();
                    for (int i = 0; i < classAttribute.numValues(); i++) {
                        if (classAttribute.value(i).equals(targetCategory)) {
                            this.m_targetCategory = i;
                        }
                    }
                }
                if (this.m_targetCategory == -1) {
                    throw new Exception("[RegressionTable] No target categories defined for classification");
                }
            }
            NodeList numericPredictors = element.getElementsByTagName("NumericPredictor");
            for (int i = 0; i < numericPredictors.getLength(); i++) {
                Node item = numericPredictors.item(i);
                if (item.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new NumericPredictor((Element) item, fields));
                }
            }
            NodeList categoricalPredictors = element.getElementsByTagName("CategoricalPredictor");
            for (int i = 0; i < categoricalPredictors.getLength(); i++) {
                Node item = categoricalPredictors.item(i);
                if (item.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new CategoricalPredictor((Element) item, fields));
                }
            }
            NodeList predictorTerms = element.getElementsByTagName("PredictorTerm");
            for (int i = 0; i < predictorTerms.getLength(); i++) {
                this.m_predictorTerms.add(new PredictorTerm((Element) predictorTerms.item(i), fields));
            }
        }

        /**
         * Writes this table's score for the schema-mapped {@code input} into its output slot of
         * {@code preds}, overwriting whatever was there.
         */
        public void predict(double[] preds, double[] input) {
            preds[outputSlot()] = this.m_intercept;
            for (Predictor predictor : this.m_predictors) {
                predictor.add(preds, input);
            }
            for (PredictorTerm term : this.m_predictorTerms) {
                term.add(preds, input);
            }
        }
    }

    /**
     * Builds the model from a PMML {@code RegressionModel} element.
     *
     * @param element the {@code RegressionModel} element
     * @param instances the incoming data header
     * @param miningSchema the model's mining schema
     * @throws Exception if {@code functionName} is neither "regression" nor "classification",
     *     no {@code RegressionTable} is present, or a table fails to parse
     */
    public Regression(Element element, Instances instances, MiningSchema miningSchema) throws Exception {
        super(instances, miningSchema);
        this.m_normalizationMethod = Normalization.NONE;
        String functionName = element.getAttribute("functionName");
        int functionType;
        if (functionName.equals("regression")) {
            functionType = RegressionTable.REGRESSION;
        } else if (functionName.equals("classification")) {
            functionType = RegressionTable.CLASSIFICATION;
        } else {
            throw new Exception("[PMML Regression] Function name not defined in pmml!");
        }
        String algorithmName = element.getAttribute("algorithmName");
        if (algorithmName != null && algorithmName.length() > 0) {
            this.m_algorithmName = algorithmName;
        }
        this.m_normalizationMethod = determineNormalization(element);
        setUpRegressionTables(element, functionType);
    }

    /** Parses every {@code RegressionTable} descendant of {@code element} into {@link #m_regressionTables}. */
    private void setUpRegressionTables(Element element, int functionType) throws Exception {
        NodeList tables = element.getElementsByTagName("RegressionTable");
        int numTables = tables.getLength();
        if (numTables == 0) {
            throw new Exception("[Regression] no regression tables defined!");
        }
        this.m_regressionTables = new RegressionTable[numTables];
        for (int i = 0; i < numTables; i++) {
            Node item = tables.item(i);
            if (item.getNodeType() == Node.ELEMENT_NODE) {
                this.m_regressionTables[i] = new RegressionTable((Element) item, functionType, this.m_miningSchema);
            }
        }
    }

    /** Maps the {@code normalizationMethod} attribute to a {@link Normalization}; unknown/absent gives NONE. */
    private static Normalization determineNormalization(Element element) {
        switch (element.getAttribute("normalizationMethod")) {
            case "simplemax":
                return Normalization.SIMPLEMAX;
            case "softmax":
                return Normalization.SOFTMAX;
            case "logit":
                return Normalization.LOGIT;
            case "probit":
                return Normalization.PROBIT;
            case "cloglog":
                return Normalization.CLOGLOG;
            case "exp":
                return Normalization.EXP;
            case "loglog":
                return Normalization.LOGLOG;
            case "cauchit":
                return Normalization.CAUCHIT;
            default:
                return Normalization.NONE;
        }
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("PMML version " + getPMMLVersion());
        if (!getCreatorApplication().equals("?")) {
            text.append("\nApplication: " + getCreatorApplication());
        }
        if (this.m_algorithmName != null) {
            text.append("\nPMML Model: " + this.m_algorithmName);
        }
        text.append("\n\n");
        text.append(this.m_miningSchema);
        for (RegressionTable table : this.m_regressionTables) {
            text.append(table);
        }
        if (this.m_normalizationMethod != Normalization.NONE) {
            text.append("Normalization: " + this.m_normalizationMethod);
        }
        text.append(StringUtils.LF);
        return text.toString();
    }

    /**
     * Scores an instance.
     *
     * <p>For a numeric target the result has length 1 and holds the predicted value, passed
     * through the target meta data's min/max/rescale/cast if present. For a nominal target the
     * result has one entry per class value. If any non-class input is missing, the result comes
     * from the target meta data (default value or prior probabilities); without such meta data a
     * warning is logged and NaN (numeric target) or all-zero probabilities are returned.
     *
     * @param instance the instance to score; the mapping to the mining schema is built on first use
     * @return the prediction or class distribution
     * @throws Exception if mapping fails or the normalization method is unknown
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.m_initialized) {
            mapToMiningSchema(instance.dataset());
        }
        Instances schema = this.m_miningSchema.getFieldsAsInstances();
        Attribute classAttribute = schema.classAttribute();
        boolean numericTarget = classAttribute.isNumeric();
        double[] preds = numericTarget ? new double[1] : new double[classAttribute.numValues()];
        double[] input = this.m_fieldsMap.instanceToSchema(instance, this.m_miningSchema);

        if (hasMissingPredictorValue(input, schema.classIndex())) {
            return predictWithMissingValues(preds, classAttribute);
        }
        for (RegressionTable table : this.m_regressionTables) {
            table.predict(preds, input);
        }
        applyNormalizationMethod(preds, numericTarget);
        if (numericTarget && this.m_miningSchema.hasTargetMetaData()) {
            preds[0] = this.m_miningSchema.getTargetMetaData().applyMinMaxRescaleCast(preds[0]);
        }
        return preds;
    }

    /** Returns true if any value other than the one at {@code classIndex} is missing. */
    private static boolean hasMissingPredictorValue(double[] input, int classIndex) {
        for (int i = 0; i < input.length; i++) {
            if (i != classIndex && Utils.isMissingValue(input[i])) {
                return true;
            }
        }
        return false;
    }

    /**
     * Fills {@code preds} for an instance with missing inputs from the target meta data, or
     * logs a warning and returns NaN / zero probabilities when there is none.
     */
    private double[] predictWithMissingValues(double[] preds, Attribute classAttribute) throws Exception {
        if (!this.m_miningSchema.hasTargetMetaData()) {
            String message = "[Regression] WARNING: Instance to predict has missing value(s) but there is no missing value handling meta data and no prior probabilities/default value to fall back to. No prediction will be made (" + ((classAttribute.isNominal() || classAttribute.isString()) ? "zero probabilities output)." : "NaN output).");
            if (this.m_log == null) {
                System.err.println(message);
            } else {
                this.m_log.logMessage(message);
            }
            if (classAttribute.isNumeric()) {
                preds[0] = Utils.missingValue();
            }
            return preds;
        }
        TargetMetaInfo targetMetaData = this.m_miningSchema.getTargetMetaData();
        if (classAttribute.isNumeric()) {
            preds[0] = targetMetaData.getDefaultValue();
        } else {
            for (int i = 0; i < classAttribute.numValues(); i++) {
                preds[i] = targetMetaData.getPriorProbability(classAttribute.value(i));
            }
        }
        return preds;
    }

    /**
     * Applies {@link #m_normalizationMethod} to the raw table outputs in place.
     *
     * <p>The per-element transforms follow the PMML RegressionModel definitions. The result is
     * rescaled to sum to one only for a nominal target: a numeric target has a single output, and
     * normalizing it would always collapse the prediction to 1.0.
     *
     * @throws Exception if the normalization method is not recognised
     */
    private void applyNormalizationMethod(double[] preds, boolean numericTarget) throws Exception {
        switch (this.m_normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                break;
            case SOFTMAX:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(preds[i]);
                }
                if (preds.length == 1) {
                    // Single output: exp(y) / (1 + exp(y)).
                    preds[0] = preds[0] / (preds[0] + 1.0d);
                    return;
                }
                break;
            case LOGIT:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = 1.0d / (1.0d + Math.exp(-preds[i]));
                }
                break;
            case PROBIT:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Maths.pnorm(preds[i]);
                }
                break;
            case CLOGLOG:
                for (int i = 0; i < preds.length; i++) {
                    // Inverse complementary log-log link: 1 - exp(-exp(y)).
                    preds[i] = 1.0d - Math.exp(-Math.exp(preds[i]));
                }
                break;
            case EXP:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(preds[i]);
                }
                break;
            case LOGLOG:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(-Math.exp(-preds[i]));
                }
                break;
            case CAUCHIT:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = 0.5d + ((1.0d / Math.PI) * Math.atan(preds[i]));
                }
                break;
            default:
                throw new Exception("[Regression] unknown normalization method");
        }
        if (!numericTarget) {
            Utils.normalize(preds);
        }
    }

    @Override // weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }
}
