package com.oracle.labs.mlrg.olcut.config.protobuf;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.TextFormat;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.ListProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.MapProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.ObjectProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.RootProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.SimpleProvenanceProto;
import com.oracle.labs.mlrg.olcut.provenance.io.FlatMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/oracle/labs/mlrg/olcut/config/protobuf/ProtoProvenanceSerialization.class */
/**
 * A {@link ProvenanceSerialization} that stores marshalled provenance as protocol buffers.
 * <p>
 * A list of {@link ObjectMarshalledProvenance} is flattened into a single
 * {@link RootProvenanceProto}. Every object, simple, list and map provenance message carries an
 * integer index that is unique within the root message. Container messages refer to their
 * children by index rather than by nesting them.
 * <p>
 * The constructor flag selects the protobuf text format or the binary wire format. Binary output
 * produced by {@link #serializeToString} is Base64 encoded.
 */
public final class ProtoProvenanceSerialization implements ProvenanceSerialization {
    private static final Base64.Encoder base64Encoder = Base64.getEncoder();
    private static final Base64.Decoder base64Decoder = Base64.getDecoder();

    /** True to use the protobuf text format, false for the binary wire format. */
    private final boolean textFormat;

    /**
     * Creates a protobuf provenance serializer.
     *
     * @param textFormat If {@code true} read and write the protobuf text format, otherwise the
     *                   binary wire format.
     */
    public ProtoProvenanceSerialization(boolean textFormat) {
        this.textFormat = textFormat;
    }

    /**
     * Returns the file extension for the configured format.
     *
     * @return {@code "pbtxt"} in text mode, {@code "pb"} in binary mode.
     */
    public String getFileExtension() {
        return this.textFormat ? "pbtxt" : "pb";
    }

    /**
     * Reads a {@link RootProvenanceProto} from a file and converts it into marshalled provenance.
     * <p>
     * In text mode the file is decoded as UTF-8 and parsed as protobuf text format. Otherwise it is
     * parsed as the binary wire format.
     *
     * @param path The file to read.
     * @return The object provenances stored in the file.
     * @throws IOException If the file could not be read.
     * @throws IllegalArgumentException If the contents are not a valid provenance protobuf.
     */
    public List<ObjectMarshalledProvenance> deserializeFromFile(Path path) throws IOException {
        // try-with-resources so the stream is closed on success and on parse failure alike.
        try (InputStream inputStream = Files.newInputStream(path)) {
            RootProvenanceProto root;
            if (this.textFormat) {
                RootProvenanceProto.Builder builder = RootProvenanceProto.newBuilder();
                TextFormat.getParser().merge(new InputStreamReader(inputStream, StandardCharsets.UTF_8), builder);
                root = builder.m426build();
            } else {
                root = RootProvenanceProto.parseFrom(inputStream);
            }
            return deserializeFromProto(root);
        } catch (InvalidProtocolBufferException | TextFormat.ParseException e) {
            throw new IllegalArgumentException("Failed to parse protobuf", e);
        }
    }

    /**
     * Parses marshalled provenance from a string.
     * <p>
     * In text mode the string is parsed as protobuf text format. Otherwise it must be the
     * Base64 encoding of the binary wire format.
     *
     * @param str The serialized provenance.
     * @return The object provenances stored in the string.
     * @throws IllegalArgumentException If the string is not valid Base64 (binary mode) or not a
     *                                  valid provenance protobuf.
     */
    public List<ObjectMarshalledProvenance> deserializeFromString(String str) {
        try {
            RootProvenanceProto root;
            if (this.textFormat) {
                RootProvenanceProto.Builder builder = RootProvenanceProto.newBuilder();
                TextFormat.getParser().merge(str, builder);
                root = builder.m426build();
            } else {
                root = RootProvenanceProto.parseFrom(base64Decoder.decode(str));
            }
            return deserializeFromProto(root);
        } catch (InvalidProtocolBufferException | TextFormat.ParseException e) {
            throw new IllegalArgumentException("Failed to parse protobuf", e);
        }
    }

    /**
     * Rebuilds marshalled provenance from a flattened {@link RootProvenanceProto}.
     * <p>
     * Every message is first placed into a lookup table by its index. Then each object
     * provenance's value references are resolved recursively through that table.
     *
     * @param rootProvenanceProto The root message.
     * @return One {@link ObjectMarshalledProvenance} per object provenance message, in the order
     *         they appear in the root message.
     * @throws IllegalArgumentException If an index is out of range, duplicated, or part of a
     *                                  reference cycle.
     * @throws IllegalStateException If a value reference points to an object provenance message.
     */
    public List<ObjectMarshalledProvenance> deserializeFromProto(RootProvenanceProto rootProvenanceProto) {
        int totalMessages = rootProvenanceProto.getLmpCount() + rootProvenanceProto.getMmpCount()
                + rootProvenanceProto.getOmpCount() + rootProvenanceProto.getSmpCount();
        Message[] messages = new Message[totalMessages];
        for (ObjectProvenanceProto omp : rootProvenanceProto.getOmpList()) {
            registerMessage(messages, omp.getIndex(), omp);
        }
        for (SimpleProvenanceProto smp : rootProvenanceProto.getSmpList()) {
            registerMessage(messages, smp.getIndex(), smp);
        }
        for (MapProvenanceProto mmp : rootProvenanceProto.getMmpList()) {
            registerMessage(messages, mmp.getIndex(), mmp);
        }
        for (ListProvenanceProto lmp : rootProvenanceProto.getLmpList()) {
            registerMessage(messages, lmp.getIndex(), lmp);
        }
        // All indices are distinct and in [0, totalMessages), so every slot is now filled.

        // Marks container indices on the current decode path, used to reject reference cycles.
        boolean[] inProgress = new boolean[messages.length];
        List<ObjectMarshalledProvenance> output = new ArrayList<>(rootProvenanceProto.getOmpCount());
        for (ObjectProvenanceProto omp : rootProvenanceProto.getOmpList()) {
            Map<String, FlatMarshalledProvenance> values = new HashMap<>();
            for (Map.Entry<String, Integer> entry : omp.getValuesMap().entrySet()) {
                values.put(entry.getKey(), dispatchMessage(messages, inProgress, entry.getValue()));
            }
            output.add(new ObjectMarshalledProvenance(omp.getObjectName(), values, omp.getObjectClassName(), omp.getProvenanceClassName()));
        }
        return output;
    }

    /**
     * Stores a message in the lookup table at its declared index.
     *
     * @throws IllegalArgumentException If the index is out of range or already occupied.
     */
    private static void registerMessage(Message[] messages, int index, Message message) {
        if (index < 0 || index >= messages.length) {
            throw new IllegalArgumentException("Invalid protobuf found, index " + index + " is outside the valid range [0, " + messages.length + "), found '" + message + "'");
        }
        if (messages[index] != null) {
            throw new IllegalArgumentException("Invalid protobuf found, index " + index + " collided, found '" + message + "' and '" + messages[index] + "'");
        }
        messages[index] = message;
    }

    /**
     * Decodes the message at {@code index} into the matching {@link FlatMarshalledProvenance}.
     *
     * @param messages   The index lookup table.
     * @param inProgress Flags for list/map indices currently being decoded, used to detect cycles.
     * @param index      The index to decode.
     * @return The decoded provenance.
     * @throws IllegalArgumentException If the index is out of range or forms a reference cycle.
     * @throws IllegalStateException If the index points to an object provenance message.
     */
    private static FlatMarshalledProvenance dispatchMessage(Message[] messages, boolean[] inProgress, int index) {
        if (index < 0 || index >= messages.length) {
            throw new IllegalArgumentException("Invalid protobuf, message index " + index + " is outside the valid range [0, " + messages.length + ")");
        }
        Message message = messages[index];
        if (message instanceof SimpleProvenanceProto) {
            return decodeSMP((SimpleProvenanceProto) message);
        }
        if (!(message instanceof ListProvenanceProto) && !(message instanceof MapProvenanceProto)) {
            throw new IllegalStateException("Invalid protobuf, a message index points to an ObjectMarshalledProvenance");
        }
        if (inProgress[index]) {
            // A container reached again while still decoding itself would recurse forever.
            throw new IllegalArgumentException("Invalid protobuf, message index " + index + " is part of a reference cycle");
        }
        inProgress[index] = true;
        try {
            if (message instanceof ListProvenanceProto) {
                return decodeLMP(messages, inProgress, (ListProvenanceProto) message);
            }
            return decodeMMP(messages, inProgress, (MapProvenanceProto) message);
        } finally {
            // Cleared on exit so the same child may legitimately be shared by several parents.
            inProgress[index] = false;
        }
    }

    /** Converts a simple provenance message into a {@link SimpleMarshalledProvenance}. */
    private static SimpleMarshalledProvenance decodeSMP(SimpleProvenanceProto simpleProvenanceProto) {
        return new SimpleMarshalledProvenance(simpleProvenanceProto.getKey(), simpleProvenanceProto.getValue(), simpleProvenanceProto.getProvenanceClassName(), simpleProvenanceProto.getIsReference(), simpleProvenanceProto.getAdditional());
    }

    /** Converts a list provenance message, resolving each element index in order. */
    private static ListMarshalledProvenance decodeLMP(Message[] messages, boolean[] inProgress, ListProvenanceProto listProvenanceProto) {
        List<FlatMarshalledProvenance> values = new ArrayList<>();
        for (int childIndex : listProvenanceProto.getValuesList()) {
            values.add(dispatchMessage(messages, inProgress, childIndex));
        }
        return new ListMarshalledProvenance(values);
    }

    /** Converts a map provenance message, resolving each value index. */
    private static MapMarshalledProvenance decodeMMP(Message[] messages, boolean[] inProgress, MapProvenanceProto mapProvenanceProto) {
        Map<String, FlatMarshalledProvenance> values = new HashMap<>();
        for (Map.Entry<String, Integer> entry : mapProvenanceProto.getValuesMap().entrySet()) {
            values.put(entry.getKey(), dispatchMessage(messages, inProgress, entry.getValue()));
        }
        return new MapMarshalledProvenance(values);
    }

    /**
     * Flattens marshalled provenance into a single {@link RootProvenanceProto}.
     * <p>
     * Indices are allocated from a single counter starting at zero. Each message takes its index
     * before its children are encoded.
     *
     * @param list The object provenances to encode.
     * @return The root message.
     */
    public RootProvenanceProto serializeToProto(List<ObjectMarshalledProvenance> list) {
        RootProvenanceProto.Builder rootBuilder = RootProvenanceProto.newBuilder();
        MutableLong counter = new MutableLong(0L);
        for (ObjectMarshalledProvenance provenance : list) {
            convertProvenance(rootBuilder, counter, provenance);
        }
        return rootBuilder.m426build();
    }

    /** Encodes one object provenance and, recursively, all of its values into the root builder. */
    private static void convertProvenance(RootProvenanceProto.Builder builder, MutableLong mutableLong, ObjectMarshalledProvenance objectMarshalledProvenance) {
        ObjectProvenanceProto.Builder ompBuilder = ObjectProvenanceProto.newBuilder();
        ompBuilder.setIndex(mutableLong.intValue());
        mutableLong.increment();
        ompBuilder.setObjectName(objectMarshalledProvenance.getName());
        ompBuilder.setObjectClassName(objectMarshalledProvenance.getObjectClassName());
        ompBuilder.setProvenanceClassName(objectMarshalledProvenance.getProvenanceClassName());
        for (Map.Entry entry : objectMarshalledProvenance.getMap().entrySet()) {
            ompBuilder.putValues((String) entry.getKey(), dispatchFMP(builder, mutableLong, (FlatMarshalledProvenance) entry.getValue()));
        }
        builder.addOmp(ompBuilder.m280build());
    }

    /**
     * Encodes a flat provenance into the root builder.
     *
     * @return The index assigned to the encoded message.
     */
    private static int dispatchFMP(RootProvenanceProto.Builder builder, MutableLong mutableLong, FlatMarshalledProvenance flatMarshalledProvenance) {
        if (flatMarshalledProvenance instanceof SimpleMarshalledProvenance) {
            return encodeSMP(builder, mutableLong, (SimpleMarshalledProvenance) flatMarshalledProvenance);
        }
        if (flatMarshalledProvenance instanceof ListMarshalledProvenance) {
            return encodeLMP(builder, mutableLong, (ListMarshalledProvenance) flatMarshalledProvenance);
        }
        if (flatMarshalledProvenance instanceof MapMarshalledProvenance) {
            return encodeMMP(builder, mutableLong, (MapMarshalledProvenance) flatMarshalledProvenance);
        }
        throw new RuntimeException("Should not reach here, unexpected FlatMarshalledProvenance subclass " + flatMarshalledProvenance.getClass());
    }

    /** Encodes a simple provenance; returns its assigned index. */
    private static int encodeSMP(RootProvenanceProto.Builder builder, MutableLong mutableLong, SimpleMarshalledProvenance simpleMarshalledProvenance) {
        SimpleProvenanceProto.Builder smpBuilder = SimpleProvenanceProto.newBuilder();
        int index = mutableLong.intValue();
        smpBuilder.setIndex(index);
        mutableLong.increment();
        smpBuilder.setKey(simpleMarshalledProvenance.getKey());
        smpBuilder.setValue(simpleMarshalledProvenance.getValue());
        smpBuilder.setAdditional(simpleMarshalledProvenance.getAdditional());
        smpBuilder.setProvenanceClassName(simpleMarshalledProvenance.getProvenanceClassName());
        smpBuilder.setIsReference(simpleMarshalledProvenance.isReference());
        builder.addSmp(smpBuilder.m520build());
        return index;
    }

    /** Encodes a list provenance and its elements (in iteration order); returns its assigned index. */
    private static int encodeLMP(RootProvenanceProto.Builder builder, MutableLong mutableLong, ListMarshalledProvenance listMarshalledProvenance) {
        ListProvenanceProto.Builder lmpBuilder = ListProvenanceProto.newBuilder();
        int index = mutableLong.intValue();
        lmpBuilder.setIndex(index);
        mutableLong.increment();
        for (Object element : listMarshalledProvenance) {
            lmpBuilder.addValues(dispatchFMP(builder, mutableLong, (FlatMarshalledProvenance) element));
        }
        builder.addLmp(lmpBuilder.m185build());
        return index;
    }

    /** Encodes a map provenance and its values; returns its assigned index. */
    private static int encodeMMP(RootProvenanceProto.Builder builder, MutableLong mutableLong, MapMarshalledProvenance mapMarshalledProvenance) {
        MapProvenanceProto.Builder mmpBuilder = MapProvenanceProto.newBuilder();
        int index = mutableLong.intValue();
        mmpBuilder.setIndex(index);
        mutableLong.increment();
        for (Object element : mapMarshalledProvenance) {
            Pair pair = (Pair) element;
            mmpBuilder.putValues((String) pair.getA(), dispatchFMP(builder, mutableLong, (FlatMarshalledProvenance) pair.getB()));
        }
        builder.addMmp(mmpBuilder.m232build());
        return index;
    }

    /**
     * Serializes provenance to a string.
     *
     * @param list The object provenances to encode.
     * @return Protobuf text format in text mode, otherwise the Base64-encoded wire format.
     */
    public String serializeToString(List<ObjectMarshalledProvenance> list) {
        RootProvenanceProto root = serializeToProto(list);
        return this.textFormat ? root.toString() : base64Encoder.encodeToString(root.toByteArray());
    }

    /**
     * Serializes provenance to a file, creating or truncating it.
     * <p>
     * Text mode writes UTF-8 protobuf text followed by the platform line separator. Binary mode
     * writes the wire format.
     *
     * @param list The object provenances to encode.
     * @param path The destination file.
     * @throws IOException If the file could not be written.
     */
    public void serializeToFile(List<ObjectMarshalledProvenance> list, Path path) throws IOException {
        RootProvenanceProto root = serializeToProto(list);
        if (this.textFormat) {
            // BufferedWriter (unlike PrintWriter) propagates write failures instead of swallowing them.
            try (BufferedWriter writer = Files.newBufferedWriter(path, StandardCharsets.UTF_8)) {
                writer.write(root.toString());
                writer.newLine();
            }
        } else {
            try (BufferedOutputStream outputStream = new BufferedOutputStream(Files.newOutputStream(path))) {
                root.writeTo(outputStream);
            }
        }
    }
}
