/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.statefun.flink.core.reqreply;

import com.google.protobuf.MoreByteStrings;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.statefun.flink.core.types.remote.RemoteValueTypeMismatchException;
import org.apache.flink.statefun.sdk.TypeName;
import org.apache.flink.statefun.sdk.annotations.Persisted;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.Expiration;
import org.apache.flink.statefun.sdk.state.PersistedStateRegistry;
import org.apache.flink.statefun.sdk.state.RemotePersistedValue;

/**
 * Keeps track of the state values of a remote function, indexed by state name, and converts
 * between those values and the request-reply protocol messages ({@link ToFunction} /
 * {@link FromFunction}).
 *
 * <p>Every state is represented by a {@link RemotePersistedValue} handle that is registered with
 * the {@link PersistedStateRegistry} held by this object and, once registration succeeded,
 * remembered in an internal name-to-handle map.
 */
public final class PersistedRemoteFunctionValues {
    /** Type assigned to a state whose protocol typename string is empty. */
    private static final TypeName UNSET_STATE_TYPE = TypeName.parseFrom("io.statefun.types/unset");

    @Persisted
    private final PersistedStateRegistry stateRegistry = new PersistedStateRegistry();

    /** Handles of all states that were successfully registered, keyed by state name. */
    private final Map<String, RemotePersistedValue> managedStates = new HashMap<>();

    /**
     * Adds one {@code PersistedValue} entry per managed state to the given batch request.
     *
     * <p>The entry always carries the state name; a typed state value is attached only when the
     * handle currently holds bytes.
     *
     * @param batchBuilder the invocation batch request under construction
     */
    void attachStateValues(ToFunction.InvocationBatchRequest.Builder batchBuilder) {
        for (Map.Entry<String, RemotePersistedValue> managedStateEntry : managedStates.entrySet()) {
            final ToFunction.PersistedValue.Builder valueBuilder =
                    ToFunction.PersistedValue.newBuilder().setStateName(managedStateEntry.getKey());

            final RemotePersistedValue registeredHandle = managedStateEntry.getValue();
            final byte[] stateBytes = registeredHandle.get();
            if (stateBytes != null) {
                final TypedValue stateValue =
                        TypedValue.newBuilder()
                                .setTypename(registeredHandle.type().canonicalTypenameString())
                                .setHasValue(true)
                                .setValue(MoreByteStrings.wrap(stateBytes))
                                .build();
                valueBuilder.setStateValue(stateValue);
            }
            batchBuilder.addState(valueBuilder);
        }
    }

    /**
     * Applies the given mutations, in order, to the managed states.
     *
     * @param valueMutations mutations received from the remote function
     * @throws IllegalStateException if a mutation targets a state that is not managed here, or if
     *     the mutation type is {@code UNRECOGNIZED} or otherwise unexpected
     * @throws RemoteFunctionStateException if a {@code MODIFY} mutation carries a value whose type
     *     differs from the type of the registered state
     */
    void updateStateValues(List<FromFunction.PersistedValueMutation> valueMutations) {
        for (FromFunction.PersistedValueMutation mutate : valueMutations) {
            final String stateName = mutate.getStateName();
            switch (mutate.getMutationType()) {
                case DELETE: {
                    getStateHandleOrThrow(stateName).clear();
                    break;
                }
                case MODIFY: {
                    final RemotePersistedValue registeredHandle = getStateHandleOrThrow(stateName);
                    final TypedValue newStateValue = mutate.getStateValue();
                    // Reject the write before touching the handle if the types disagree.
                    validateType(registeredHandle, newStateValue.getTypename());
                    registeredHandle.set(newStateValue.getValue().toByteArray());
                    break;
                }
                case UNRECOGNIZED: {
                    throw new IllegalStateException("Received an UNRECOGNIZED PersistedValueMutation type. This may be caused by a mismatch or incompatibility with the remote function SDK version and the Stateful Functions version.");
                }
                default: {
                    throw new IllegalStateException("Unexpected value: " + mutate.getMutationType());
                }
            }
        }
    }

    /**
     * Registers every state described by the given specs that is not managed yet; for states that
     * are already managed, verifies that the declared type matches the registered one.
     *
     * @param protocolPersistedValueSpecs state specs received from the remote function
     * @throws RemoteFunctionStateException if a spec's type conflicts with an already managed
     *     state, or if the state registry rejects the new state with a type mismatch
     */
    void registerStates(List<FromFunction.PersistedValueSpec> protocolPersistedValueSpecs) {
        protocolPersistedValueSpecs.forEach(this::createAndRegisterValueStateIfAbsent);
    }

    /** Registers the state if its name is unknown, otherwise only validates its declared type. */
    private void createAndRegisterValueStateIfAbsent(FromFunction.PersistedValueSpec protocolPersistedValueSpec) {
        final RemotePersistedValue stateHandle = managedStates.get(protocolPersistedValueSpec.getStateName());
        if (stateHandle == null) {
            registerValueState(protocolPersistedValueSpec);
        } else {
            validateType(stateHandle, protocolPersistedValueSpec.getTypeTypename());
        }
    }

    /**
     * Creates a handle for the given spec, registers it with the state registry and, only if that
     * succeeded, starts managing it.
     *
     * @throws RemoteFunctionStateException if the registry reports a type mismatch
     */
    private void registerValueState(FromFunction.PersistedValueSpec protocolPersistedValueSpec) {
        final String stateName = protocolPersistedValueSpec.getStateName();
        final RemotePersistedValue remoteValueState =
                RemotePersistedValue.of(
                        stateName,
                        sdkStateType(protocolPersistedValueSpec.getTypeTypename()),
                        sdkTtlExpiration(protocolPersistedValueSpec.getExpirationSpec()));

        try {
            stateRegistry.registerRemoteValue(remoteValueState);
        } catch (RemoteValueTypeMismatchException e) {
            throw new RemoteFunctionStateException(stateName, e);
        }

        // Record the handle only after the registry accepted it. Doing it the other way around
        // would leave a never-registered handle in the map when registration fails, and a later
        // registration attempt for the same name would then be skipped as "already present".
        managedStates.put(stateName, remoteValueState);
    }

    /**
     * Checks that the given protocol typename denotes the same type as the one of the handle.
     *
     * @throws RemoteFunctionStateException wrapping a {@link RemoteValueTypeMismatchException} if
     *     the types differ
     */
    private void validateType(RemotePersistedValue previousStateHandle, String protocolTypenameString) {
        final TypeName newStateType = sdkStateType(protocolTypenameString);
        if (!newStateType.equals(previousStateHandle.type())) {
            throw new RemoteFunctionStateException(
                    previousStateHandle.name(),
                    new RemoteValueTypeMismatchException(previousStateHandle.type(), newStateType));
        }
    }

    /** Parses the protocol typename; an empty string maps to {@link #UNSET_STATE_TYPE}. */
    private static TypeName sdkStateType(String protocolTypenameString) {
        return protocolTypenameString.isEmpty() ? UNSET_STATE_TYPE : TypeName.parseFrom(protocolTypenameString);
    }

    /**
     * Converts a protocol expiration spec into an SDK {@link Expiration}.
     *
     * <p>{@code AFTER_INVOKE} maps to expire-after-reading-or-writing, {@code AFTER_WRITE} to
     * expire-after-writing; every other mode yields {@link Expiration#none()}.
     */
    private static Expiration sdkTtlExpiration(FromFunction.ExpirationSpec protocolExpirationSpec) {
        final long expirationTtlMillis = protocolExpirationSpec.getExpireAfterMillis();
        switch (protocolExpirationSpec.getMode()) {
            case AFTER_INVOKE: {
                return Expiration.expireAfterReadingOrWriting(Duration.ofMillis(expirationTtlMillis));
            }
            case AFTER_WRITE: {
                return Expiration.expireAfterWriting(Duration.ofMillis(expirationTtlMillis));
            }
            default: {
                return Expiration.none();
            }
        }
    }

    /**
     * Looks up the handle of a managed state.
     *
     * @throws IllegalStateException if no state with the given name is managed
     */
    private RemotePersistedValue getStateHandleOrThrow(String stateName) {
        final RemotePersistedValue handle = managedStates.get(stateName);
        if (handle == null) {
            throw new IllegalStateException("Accessing a non-existing function state: " + stateName + ". This can happen if you forgot to declare this state using the language SDKs.");
        }
        return handle;
    }

    /** Signals a problem with a specific remote function state; the message names the state. */
    public static class RemoteFunctionStateException
    extends RuntimeException {
        private static final long serialVersionUID = 1L;

        private RemoteFunctionStateException(String stateName, Throwable cause) {
            super("An error occurred for state [" + stateName + "].", cause);
        }
    }
}

