/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.physical.stream;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.calcite.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.catalog.ContextResolvedFunction;
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecProcessTableFunction;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel;
import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.SystemTypeInference;
import org.apache.flink.types.RowKind;

public class StreamPhysicalProcessTableFunction
extends AbstractRelNode
implements StreamPhysicalRel {
    private final FlinkLogicalTableFunctionScan scan;
    private final @Nullable String uid;
    private List<RelNode> inputs;

    public StreamPhysicalProcessTableFunction(RelOptCluster cluster, RelTraitSet traitSet, List<RelNode> inputs, FlinkLogicalTableFunctionScan scan, RelDataType rowType) {
        super(cluster, traitSet);
        this.inputs = inputs;
        this.rowType = rowType;
        this.scan = scan;
        this.uid = StreamPhysicalProcessTableFunction.deriveUniqueIdentifier(scan);
        StreamPhysicalProcessTableFunction.verifyInputSize(ShortcutUtils.unwrapTableConfig(cluster), inputs.size());
    }

    public StreamPhysicalProcessTableFunction(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, FlinkLogicalTableFunctionScan scan, RelDataType rowType) {
        this(cluster, traitSet, List.of(input), scan, rowType);
    }

    public RexCall getCall() {
        return (RexCall)this.scan.getCall();
    }

    @Override
    public boolean requireWatermark() {
        return true;
    }

    @Override
    public List<RelNode> getInputs() {
        return this.inputs;
    }

    @Override
    public void replaceInput(int ordinalInParent, RelNode p) {
        ArrayList<RelNode> newInputs = new ArrayList<RelNode>(this.inputs);
        newInputs.set(ordinalInParent, p);
        this.inputs = List.copyOf(newInputs);
        this.recomputeDigest();
    }

    @Override
    public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
        return new StreamPhysicalProcessTableFunction(this.getCluster(), traitSet, inputs, this.scan, this.getRowType());
    }

    @Override
    public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        double elementRate = 100.0 * (double)this.getInputs().size();
        return planner.getCostFactory().makeCost(elementRate, elementRate, 0.0);
    }

    @Override
    public ExecNode<?> translateToExecNode() {
        List<ChangelogMode> inputChangelogModes = this.getInputs().stream().map(StreamPhysicalRel.class::cast).map(ChangelogPlanUtils::getChangelogMode).map(JavaScalaConversionUtil::toJava).map(optional -> (ChangelogMode)optional.orElseThrow(IllegalStateException::new)).collect(Collectors.toList());
        ChangelogMode outputChangelogMode = JavaScalaConversionUtil.toJava(ChangelogPlanUtils.getChangelogMode(this)).orElseThrow(IllegalStateException::new);
        RexCall call = (RexCall)this.scan.getCall();
        StreamPhysicalProcessTableFunction.verifyTimeAttributes(this.getInputs(), call, inputChangelogModes, outputChangelogMode);
        List<Ord<StaticArgument>> providedInputArgs = StreamPhysicalProcessTableFunction.getProvidedInputArgs(call);
        StreamPhysicalProcessTableFunction.verifyPassThroughColumnsForUpdates(providedInputArgs, outputChangelogMode);
        return new StreamExecProcessTableFunction((ReadableConfig)ShortcutUtils.unwrapTableConfig(this), this.getInputs().stream().map(i -> InputProperty.DEFAULT).collect(Collectors.toList()), FlinkTypeFactory.toLogicalRowType(this.rowType), this.getRelDetailedDescription(), this.uid, call, inputChangelogModes, outputChangelogMode);
    }

    @Override
    public RelWriter explainTerms(RelWriter pw) {
        super.explainTerms(pw);
        for (Ord<RelNode> ord : Ord.zip(this.inputs)) {
            pw.input("input#" + ord.i, (RelNode)ord.e);
        }
        return pw.item("invocation", this.scan.getCall()).item("uid", this.uid).item("select", String.join((CharSequence)",", this.getRowType().getFieldNames())).item("rowType", this.getRowType());
    }

    @Override
    protected RelDataType deriveRowType() {
        return this.rowType;
    }

    private static @Nullable String deriveUniqueIdentifier(FlinkLogicalTableFunctionScan scan) {
        RexCall rexCall = (RexCall)scan.getCall();
        BridgingSqlFunction.WithTableFunction function = (BridgingSqlFunction.WithTableFunction)rexCall.getOperator();
        List staticArgs = (List)function.getTypeInference().getStaticArguments().orElseThrow(IllegalStateException::new);
        ContextResolvedFunction resolvedFunction = function.getResolvedFunction();
        List<RexNode> operands = rexCall.getOperands();
        RexNode uidRexNode = operands.get(operands.size() - 1);
        if (uidRexNode.getKind() == SqlKind.DEFAULT) {
            if (staticArgs.stream().noneMatch(arg -> arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE))) {
                return null;
            }
            String uid = resolvedFunction.getIdentifier().map(FunctionIdentifier::getFunctionName).orElse("");
            if (SystemTypeInference.isInvalidUidForProcessTableFunction((String)uid)) {
                throw new ValidationException(String.format("Could not derive a unique identifier for process table function '%s'. The function's name does not qualify for a UID. Please provide a custom identifier using the implicit `uid` argument. For example: myFunction(..., uid => 'my-id')", resolvedFunction.asSummaryString()));
            }
            return uid;
        }
        return RexLiteral.stringValue(uidRexNode);
    }

    private static void verifyTimeAttributes(List<RelNode> inputs, RexCall call, List<ChangelogMode> inputChangelogModes, ChangelogMode outputChangelogMode) {
        Set<String> onTimeFields = StreamPhysicalProcessTableFunction.deriveOnTimeFields(call);
        StreamPhysicalProcessTableFunction.verifyOnTimeForUpdates(onTimeFields, inputChangelogModes, outputChangelogMode);
        inputs.stream().map(RelNode::getRowType).forEach(rowType -> StreamPhysicalProcessTableFunction.verifyTimeAttribute(rowType, onTimeFields));
    }

    private static void verifyTimeAttribute(RelDataType rowType, Set<String> onTimeFields) {
        onTimeFields.stream().map(onTimeField -> rowType.getField((String)onTimeField, true, false)).filter(Objects::nonNull).forEach(StreamPhysicalProcessTableFunction::verifyTimeAttribute);
    }

    private static void verifyTimeAttribute(RelDataTypeField timeColumn) {
        if (!FlinkTypeFactory.isTimeIndicatorType(timeColumn.getType())) {
            throw new ValidationException(String.format("Column '%s' is not a valid time attribute. Only columns with a watermark declaration qualify for the `on_time` argument. Also, make sure that the watermarked column is forwarded without any modification.", timeColumn.getName()));
        }
    }

    private static void verifyOnTimeForUpdates(Set<String> onTimeFields, List<ChangelogMode> inputChangelogModes, ChangelogMode outputChangelogMode) {
        boolean isUpdating;
        if (onTimeFields.isEmpty()) {
            return;
        }
        boolean bl = isUpdating = inputChangelogModes.stream().anyMatch(c -> !c.containsOnly(RowKind.INSERT)) || !outputChangelogMode.containsOnly(RowKind.INSERT);
        if (isUpdating) {
            throw new ValidationException("Time operations using the `on_time` argument are currently not supported for PTFs that consume or produce updates.");
        }
    }

    private static void verifyPassThroughColumnsForUpdates(List<Ord<StaticArgument>> providedInputArgs, ChangelogMode requiredChangelogMode) {
        if (!requiredChangelogMode.containsOnly(RowKind.INSERT) && providedInputArgs.stream().anyMatch(arg -> ((StaticArgument)arg.e).is(StaticArgumentTrait.PASS_COLUMNS_THROUGH))) {
            throw new ValidationException("Pass-through columns are not supported for PTFs that produce updates.");
        }
    }

    private static void verifyInputSize(TableConfig tableConfig, int providedInputArgs) {
        int maxCount = (Integer)tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_PTF_MAX_TABLES);
        if (providedInputArgs > maxCount) {
            throw new ValidationException(String.format("Unsupported table argument count. Currently, the number of input tables is limited to %s.", maxCount));
        }
    }

    public static List<Ord<StaticArgument>> getProvidedInputArgs(RexCall call) {
        List<RexNode> operands = call.getOperands();
        BridgingSqlFunction.WithTableFunction function = (BridgingSqlFunction.WithTableFunction)call.getOperator();
        List declaredArgs = (List)function.getTypeInference().getStaticArguments().orElseThrow(IllegalStateException::new);
        return Ord.zip(declaredArgs).stream().filter(arg -> ((StaticArgument)arg.e).is(StaticArgumentTrait.TABLE)).filter(arg -> operands.get(arg.i) instanceof RexTableArgCall).collect(Collectors.toList());
    }

    public static Set<String> deriveOnTimeFields(RexCall call) {
        List<RexNode> operands = call.getOperands();
        RexCall onTimeOperand = (RexCall)operands.get(operands.size() - 1 - 1);
        if (onTimeOperand.getKind() == SqlKind.DEFAULT) {
            return Set.of();
        }
        return onTimeOperand.getOperands().stream().map(RexLiteral::stringValue).filter(Objects::nonNull).collect(Collectors.toSet());
    }

    public static RexCall toUdfCall(RexCall call) {
        BridgingSqlFunction function = ShortcutUtils.unwrapBridgingSqlFunction(call);
        assert (function != null);
        List staticArgs = (List)function.getTypeInference().getStaticArguments().orElseThrow(IllegalStateException::new);
        List<RexNode> operands = call.getOperands();
        List<RexNode> newOperands = operands.subList(0, operands.size() - SystemTypeInference.PROCESS_TABLE_FUNCTION_SYSTEM_ARGS.size());
        int prefixOutputSystemFields = Ord.zip(newOperands).stream().mapToInt(operand -> {
            if (!(operand.e instanceof RexTableArgCall)) {
                return 0;
            }
            RexTableArgCall tableArg = (RexTableArgCall)operand.e;
            StaticArgument staticArg = (StaticArgument)staticArgs.get(0);
            if (staticArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
                return tableArg.getType().getFieldCount();
            }
            return tableArg.getPartitionKeys().length;
        }).sum();
        RexNode onTimeArg = operands.get(operands.size() - 1 - 1);
        int suffixOutputSystemFields = onTimeArg.getKind() == SqlKind.DEFAULT || RexUtil.isNullLiteral(onTimeArg, true) ? 0 : 1;
        List<RelDataTypeField> projectedFields = call.getType().getFieldList().subList(prefixOutputSystemFields, call.getType().getFieldCount() - suffixOutputSystemFields);
        RelDataType newReturnType = function.getTypeFactory().createStructType(projectedFields);
        return call.clone(newReturnType, newOperands);
    }

    public static List<Integer> toInputTimeColumns(RexCall call) {
        List<RexNode> operands = call.getOperands();
        Set<String> onTimeFields = StreamPhysicalProcessTableFunction.deriveOnTimeFields(call);
        List<Ord<StaticArgument>> providedInputArgs = StreamPhysicalProcessTableFunction.getProvidedInputArgs(call);
        return providedInputArgs.stream().map(providedInputArg -> {
            RexTableArgCall tableArgCall = (RexTableArgCall)operands.get(providedInputArg.i);
            return onTimeFields.stream().map(onTimeField -> tableArgCall.getType().getField((String)onTimeField, true, false)).filter(Objects::nonNull).map(RelDataTypeField::getIndex).findFirst().orElse(-1);
        }).collect(Collectors.toList());
    }

    public static Set<ImmutableBitSet> toPartitionColumns(RexCall call) {
        List<RexNode> operands = call.getOperands();
        List<Ord<StaticArgument>> providedInputArgs = StreamPhysicalProcessTableFunction.getProvidedInputArgs(call);
        HashSet<ImmutableBitSet> partitionColumnsPerArg = new HashSet<ImmutableBitSet>();
        int pos = 0;
        for (Ord<StaticArgument> providedInputArg : providedInputArgs) {
            RexTableArgCall tableArgCall = (RexTableArgCall)operands.get(providedInputArg.i);
            if (((StaticArgument)providedInputArg.e).is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
                assert (providedInputArgs.size() == 1);
                List<Integer> partitionColumns = Arrays.stream(tableArgCall.getPartitionKeys()).boxed().collect(Collectors.toList());
                partitionColumnsPerArg.add(ImmutableBitSet.of(partitionColumns));
                continue;
            }
            int partitionKeyCount = tableArgCall.getPartitionKeys().length;
            List<Integer> partitionColumns = IntStream.range(pos, partitionKeyCount).boxed().collect(Collectors.toList());
            pos += partitionKeyCount;
            partitionColumnsPerArg.add(ImmutableBitSet.of(partitionColumns));
        }
        return ImmutableSet.copyOf(partitionColumnsPerArg);
    }

    public static CallContext toCallContext(RexCall udfCall, List<Integer> inputTimeColumns, List<ChangelogMode> inputChangelogModes, @Nullable ChangelogMode outputChangelogMode) {
        BridgingSqlFunction function = ShortcutUtils.unwrapBridgingSqlFunction(udfCall);
        assert (function != null);
        FunctionDefinition definition = ShortcutUtils.unwrapFunctionDefinition(udfCall);
        return new OperatorBindingCallContext(function.getDataTypeFactory(), definition, RexCallBinding.create(function.getTypeFactory(), udfCall, Collections.emptyList()), udfCall.getType(), inputTimeColumns, inputChangelogModes, outputChangelogMode);
    }
}

