/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.cassandra.spark.reader;

import java.io.DataInputStream;
import java.io.EOFException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.zip.CRC32;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Ints;

import org.apache.cassandra.bridge.TokenRange;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.ClusteringPrefix;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.SerializationHeader;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.ByteBufferAccessor;
import org.apache.cassandra.db.marshal.CompositeType;
import org.apache.cassandra.db.marshal.TypeParser;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.db.rows.EncodingStats;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.io.sstable.CorruptSSTableException;
import org.apache.cassandra.io.sstable.Descriptor;
import org.apache.cassandra.io.sstable.format.SSTableFormat;
import org.apache.cassandra.io.sstable.format.Version;
import org.apache.cassandra.io.sstable.format.bti.BtiReaderUtils;
import org.apache.cassandra.io.sstable.format.bti.PartitionIndex;
import org.apache.cassandra.io.sstable.metadata.MetadataComponent;
import org.apache.cassandra.io.sstable.metadata.MetadataType;
import org.apache.cassandra.io.sstable.metadata.ValidationMetadata;
import org.apache.cassandra.io.util.ChannelProxy;
import org.apache.cassandra.io.util.DataInputBuffer;
import org.apache.cassandra.io.util.DataInputPlus;
import org.apache.cassandra.io.util.DataInputStreamPlus;
import org.apache.cassandra.io.util.RebufferingChannelInputStream;
import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.io.util.ReadOnlyInputStreamFileChannel;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.spark.data.FileType;
import org.apache.cassandra.spark.data.SSTable;
import org.apache.cassandra.spark.sparksql.filters.PartitionKeyFilter;
import org.apache.cassandra.spark.utils.ByteBufferUtils;
import org.apache.cassandra.spark.utils.Pair;
import org.apache.cassandra.spark.utils.Preconditions;
import org.apache.cassandra.spark.utils.streaming.BufferingInputStream;
import org.apache.cassandra.utils.BloomFilter;
import org.apache.cassandra.utils.BloomFilterSerializer;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.TokenUtils;
import org.apache.cassandra.utils.vint.VIntCoding;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import static org.apache.cassandra.utils.FBUtilities.updateChecksumInt;

@SuppressWarnings("WeakerAccess")
public final class ReaderUtils extends TokenUtils
{
    private static final int CHECKSUM_LENGTH = 4;  // CRC32
    private static final Constructor<?> SERIALIZATION_HEADER =
    Arrays.stream(SerializationHeader.Component.class.getDeclaredConstructors())
          .filter(constructor -> constructor.getParameterCount() == 5)
          .findFirst()
          .orElseThrow(() -> new RuntimeException("Could not find SerializationHeader.Component constructor"));
    public static final ByteBuffer SUPER_COLUMN_MAP_COLUMN = ByteBufferUtil.EMPTY_BYTE_BUFFER;

    static
    {
        // The SerializationHeader.Component constructor is looked up via getDeclaredConstructors() and invoked
        // reflectively in deserializeSerializationHeader(); make it accessible once (it may not be public)
        SERIALIZATION_HEADER.setAccessible(true);
    }

    private ReaderUtils()
    {
        super();
        throw new IllegalStateException(getClass() + " is static utility class and shall not be instantiated");
    }

    /**
     * Builds a Cassandra {@link Descriptor} for the data file of the given SSTable
     *
     * @param keyspace Name of the keyspace
     * @param table    Name of the table
     * @param ssTable  SSTable whose data file name is used
     * @return descriptor parsed from the (normalized) data file path
     */
    public static Descriptor constructDescriptor(@NotNull String keyspace, @NotNull String table, @NotNull SSTable ssTable)
    {
        File file = ReaderUtils.constructFilename(keyspace, table, ssTable.getDataFileName());
        return Descriptor.fromFile(file);
    }

    /**
     * Constructs full file path for a given combination of keyspace, table, and data file name,
     * while adjusting for data files with non-standard names prefixed with keyspace and table
     *
     * @param keyspace Name of the keyspace
     * @param table    Name of the table
     * @param filename Name of the data file
     * @return A full file path, adjusted for non-standard file names
     */
    @VisibleForTesting
    @NotNull
    public static File constructFilename(@NotNull String keyspace, @NotNull String table, @NotNull String filename)
    {
        String[] components = filename.split("-");
        if (components.length == 6
            && components[0].equals(keyspace)
            && components[1].equals(table))
        {
            // Strip the "<keyspace>-<table>-" prefix (two names plus two dashes)
            filename = filename.substring(keyspace.length() + table.length() + 2);
        }

        return new File(String.format("./%s/%s", keyspace, table), filename);
    }

    /**
     * Encodes a legacy (pre-3.0 style) cell name from a clustering prefix, column name and optional collection element
     *
     * @param metadata          table metadata, used for its compound/dense/super flags and comparator size
     * @param clustering        clustering prefix of the cell, or {@link Clustering#STATIC_CLUSTERING} for static cells
     * @param columnName        column name
     * @param collectionElement collection element, or {@code null} when the cell is not part of a collection
     * @return the encoded cell name
     */
    static ByteBuffer encodeCellName(TableMetadata metadata,
                                     ClusteringPrefix clustering,
                                     ByteBuffer columnName,
                                     ByteBuffer collectionElement)
    {
        boolean isStatic = clustering == Clustering.STATIC_CLUSTERING;

        if (!TableMetadata.Flag.isCompound(metadata.flags))
        {
            if (isStatic)
            {
                return columnName;
            }

            assert clustering.size() == 1 : "Expected clustering size to be 1, but was " + clustering.size();
            return clustering.bufferAt(0);
        }

        // We use comparator.size() rather than clustering.size() because of static clusterings
        int clusteringSize = metadata.comparator.size();
        int size = clusteringSize + (TableMetadata.Flag.isDense(metadata.flags) ? 0 : 1)
                   + (collectionElement == null ? 0 : 1);
        if (TableMetadata.Flag.isSuper(metadata.flags))
        {
            size = clusteringSize + 1;
        }

        ByteBuffer[] values = new ByteBuffer[size];
        for (int index = 0; index < clusteringSize; index++)
        {
            if (isStatic)
            {
                values[index] = ByteBufferUtil.EMPTY_BYTE_BUFFER;
                continue;
            }

            ByteBuffer value = clustering.bufferAt(index);
            // We can have null (only for dense compound tables for backward compatibility reasons),
            // but that means we're done and should stop there as far as building the composite is concerned
            if (value == null)
            {
                return CompositeType.build(ByteBufferAccessor.instance, Arrays.copyOfRange(values, 0, index));
            }

            values[index] = value;
        }

        if (TableMetadata.Flag.isSuper(metadata.flags))
        {
            // We need to set the "column" (in thrift terms) name, i.e. the value corresponding to the subcomparator.
            // What it is depends on whether this is a cell for a declared "static" column
            // or a "dynamic" column part of the super-column internal map.
            assert columnName != null;  // This should never be null for supercolumns
            values[clusteringSize] = columnName.equals(SUPER_COLUMN_MAP_COLUMN)
                                     ? collectionElement
                                     : columnName;
        }
        else
        {
            if (!TableMetadata.Flag.isDense(metadata.flags))
            {
                values[clusteringSize] = columnName;
            }
            if (collectionElement != null)
            {
                values[clusteringSize + 1] = collectionElement;
            }
        }

        return CompositeType.build(ByteBufferAccessor.instance, isStatic, values);
    }

    /**
     * Reads the first and last partition keys from the SSTable's primary index and returns the closed token range
     * spanning them
     *
     * @param tableMetadata table metadata providing the partitioner
     * @param sstable       SSTable to read
     * @return closed token range between the first and last keys
     * @throws IOException if the index cannot be read
     */
    @Nullable
    public static TokenRange tokenRangeFromIndex(@NotNull TableMetadata tableMetadata,
                                                 @NotNull SSTable sstable) throws IOException
    {
        Pair<DecoratedKey, DecoratedKey> firstLastKeys = keysFromIndex(tableMetadata, sstable);
        Preconditions.checkNotNull(firstLastKeys, "No first and last keys read from index of %s", sstable.getDataFileName());
        return TokenRange.closed(tokenToBigInteger(firstLastKeys.getLeft().getToken()),
                                 tokenToBigInteger(firstLastKeys.getRight().getToken()));
    }

    /**
     * Reads the first and last decorated keys from the SSTable's primary index, using the table's partitioner
     *
     * @see #keysFromIndex(IPartitioner, SSTable)
     */
    @Nullable
    public static Pair<DecoratedKey, DecoratedKey> keysFromIndex(@NotNull TableMetadata metadata,
                                                                 @NotNull SSTable ssTable) throws IOException
    {
        return keysFromIndex(metadata.partitioner, ssTable);
    }

    /**
     * Reads the first and last decorated keys from the SSTable's primary index.
     * For BIG format the Index.db stream is scanned; otherwise the partition index is loaded over the
     * Partitions.db stream.
     *
     * @param partitioner partitioner used to decorate the keys
     * @param ssTable     SSTable to read
     * @return pair of (first, last) keys, or {@code null} if the index stream is unavailable or a BIG-format
     *         index contains no entries
     * @throws IOException if the index cannot be read
     */
    @Nullable
    public static Pair<DecoratedKey, DecoratedKey> keysFromIndex(@NotNull IPartitioner partitioner,
                                                                 @NotNull SSTable ssTable) throws IOException
    {
        try (InputStream primaryIndex = ssTable.openPrimaryIndexStream())
        {
            if (primaryIndex == null)
            {
                return null;
            }

            if (ssTable.isBigFormat())
            {
                Pair<ByteBuffer, ByteBuffer> keys = primaryIndexReadFirstAndLastKey(primaryIndex);
                if (keys.left == null || keys.right == null)
                {
                    // Empty Index.db: there is no key to decorate (decorating null would fail)
                    return null;
                }
                return Pair.of(partitioner.decorateKey(keys.left), partitioner.decorateKey(keys.right));
            }

            File file = new File(ssTable.getDataFileName());
            // NOTE(review): assumes non-BIG primary index streams are always BufferingInputStream instances
            BufferingInputStream<?> bis = (BufferingInputStream<?>) primaryIndex;
            long size = ssTable.length(FileType.PARTITIONS_INDEX);
            try (ReadOnlyInputStreamFileChannel fileChannel = new ReadOnlyInputStreamFileChannel(bis, size);
                 ChannelProxy proxy = new ChannelProxy(file, fileChannel);
                 FileHandle fileHandle = new FileHandle.Builder(file).complete(f -> proxy);
                 PartitionIndex partitionIndex = PartitionIndex.load(fileHandle, partitioner, false))
            {
                return Pair.of(partitionIndex.firstKey(), partitionIndex.lastKey());
            }
        }
    }

    /**
     * Checks whether any of the given partition key filters match a key in the SSTable's primary index
     *
     * @param ssTable    SSTable to check
     * @param metadata   table metadata
     * @param descriptor SSTable descriptor
     * @param filters    partition key filters to look for
     * @return {@code false} if there are no filters; otherwise whether any filter matched, or {@code true} if the
     *         BIG-format primary index could not be opened (conservative answer)
     * @throws IOException if the index cannot be read
     */
    public static boolean anyFilterKeyInIndex(@NotNull SSTable ssTable,
                                              @NotNull TableMetadata metadata,
                                              @NotNull Descriptor descriptor,
                                              @NotNull List<PartitionKeyFilter> filters) throws IOException
    {
        if (filters.isEmpty())
        {
            return false;
        }

        if (ssTable.isBtiFormat())
        {
            return BtiReaderUtils.primaryIndexContainsAnyKey(ssTable, metadata, descriptor, filters);
        }

        try (InputStream primaryIndex = ssTable.openPrimaryIndexStream())
        {
            if (primaryIndex != null)
            {
                return primaryIndexContainsAnyKey(primaryIndex, filters);
            }
        }

        return true; // could not read primary index, so to be safe assume it contains the keys
    }

    /**
     * Deserializes the selected metadata components of the SSTable's Statistics.db file
     *
     * @param keyspace      Name of the keyspace, used to build the descriptor
     * @param table         Name of the table, used to build the descriptor
     * @param ssTable       SSTable to read
     * @param selectedTypes metadata types to deserialize
     * @return map of MetadataComponent for each requested MetadataType found in the file
     * @throws IOException if the stats file cannot be read
     */
    public static Map<MetadataType, MetadataComponent> deserializeStatsMetadata(String keyspace,
                                                                                String table,
                                                                                SSTable ssTable,
                                                                                EnumSet<MetadataType> selectedTypes) throws IOException
    {
        return deserializeStatsMetadata(ssTable, selectedTypes, constructDescriptor(keyspace, table, ssTable));
    }

    /**
     * Deserializes the VALIDATION, STATS and HEADER metadata components of the SSTable's Statistics.db file
     *
     * @param ssTable    SSTable to read
     * @param descriptor SSTable file descriptor
     * @return map of MetadataComponent for each of those MetadataTypes found in the file
     * @throws IOException if the stats file cannot be read
     */
    public static Map<MetadataType, MetadataComponent> deserializeStatsMetadata(SSTable ssTable,
                                                                                Descriptor descriptor) throws IOException
    {
        return deserializeStatsMetadata(ssTable,
                                        EnumSet.of(MetadataType.VALIDATION, MetadataType.STATS, MetadataType.HEADER),
                                        descriptor);
    }

    /**
     * Opens the SSTable's Statistics.db stream and deserializes the selected metadata components from it;
     * the stream is closed afterwards
     *
     * @param ssTable       SSTable to read
     * @param selectedTypes metadata types to deserialize
     * @param descriptor    SSTable file descriptor
     * @return map of MetadataComponent for each requested MetadataType found in the file
     * @throws IOException if the stats file cannot be read
     */
    public static Map<MetadataType, MetadataComponent> deserializeStatsMetadata(SSTable ssTable,
                                                                                EnumSet<MetadataType> selectedTypes,
                                                                                Descriptor descriptor) throws IOException
    {
        try (InputStream statsStream = ssTable.openStatsStream())
        {
            return deserializeStatsMetadata(statsStream,
                                            selectedTypes,
                                            descriptor);
        }
    }

    /**
     * Deserialize Statistics.db file to pull out metadata components needed for SSTable deserialization
     *
     * @param is            input stream for Statistics.db file; not closed by this method
     * @param selectedTypes enum of MetadataType to deserialize
     * @param descriptor    SSTable file descriptor
     * @return map of MetadataComponent for each requested MetadataType (empty if the file lists no components)
     * @throws IOException           if the stream cannot be read
     * @throws CorruptSSTableException if a checksum does not match or the table of contents is malformed
     */
    static Map<MetadataType, MetadataComponent> deserializeStatsMetadata(InputStream is,
                                                                         EnumSet<MetadataType> selectedTypes,
                                                                         Descriptor descriptor) throws IOException
    {
        DataInputStream in = new DataInputStreamPlus(is);
        boolean isChecksummed = descriptor.version.hasMetadataChecksum();
        CRC32 crc = new CRC32();

        // Table of contents: component count, then (ordinal, offset) pairs; each part optionally followed by a checksum
        int count = in.readInt();
        updateChecksumInt(crc, count);
        maybeValidateChecksum(crc, in, descriptor);
        if (count < 0)
        {
            throw corruptStatsException(descriptor, "Invalid metadata component count " + count + " in ");
        }

        int[] ordinals = new int[count];
        int[] offsets = new int[count];
        int[] lengths = new int[count];

        for (int index = 0; index < count; index++)
        {
            ordinals[index] = in.readInt();
            updateChecksumInt(crc, ordinals[index]);

            offsets[index] = in.readInt();
            updateChecksumInt(crc, offsets[index]);
        }
        maybeValidateChecksum(crc, in, descriptor);

        Map<MetadataType, MetadataComponent> components = new EnumMap<>(MetadataType.class);
        if (count == 0)
        {
            // No components listed; there is no final component to read either
            return components;
        }

        // Lengths (including any trailing checksum) are derivable for every component but the last
        for (int index = 0; index < count - 1; index++)
        {
            lengths[index] = offsets[index + 1] - offsets[index];
        }

        MetadataType[] allMetadataTypes = MetadataType.values();
        for (int index = 0; index < count - 1; index++)
        {
            MetadataType type = allMetadataTypes[ordinals[index]];

            if (!selectedTypes.contains(type))
            {
                // DataInputStream.skipBytes() may skip fewer bytes than requested on a streaming source, which would
                // misalign every subsequent read; skip the whole component (including its checksum) fully
                ByteBufferUtils.skipBytesFully(in, lengths[index]);
                continue;
            }

            byte[] bytes = new byte[isChecksummed ? lengths[index] - CHECKSUM_LENGTH : lengths[index]];
            in.readFully(bytes);

            crc.reset();
            crc.update(bytes);
            maybeValidateChecksum(crc, in, descriptor);

            components.put(type, deserializeMetadataComponent(descriptor.version, bytes, type));
        }

        MetadataType lastType = allMetadataTypes[ordinals[count - 1]];
        if (!selectedTypes.contains(lastType))
        {
            return components;
        }

        // We do not have in.bytesRemaining() (as in FileDataInput),
        // so need to read remaining bytes to get final component
        byte[] remainingBytes = ByteBufferUtils.readRemainingBytes(in, 256);
        byte[] bytes;
        if (isChecksummed)
        {
            if (remainingBytes.length < CHECKSUM_LENGTH)
            {
                throw corruptStatsException(descriptor, "Truncated final metadata component in ");
            }
            ByteBuffer buffer = ByteBuffer.wrap(remainingBytes);
            bytes = new byte[buffer.remaining() - CHECKSUM_LENGTH];
            buffer.get(bytes);
            crc.reset();
            crc.update(bytes);
            validateChecksum(crc, buffer.getInt(), descriptor);
        }
        else
        {
            bytes = remainingBytes;
        }

        components.put(lastType, deserializeMetadataComponent(descriptor.version, bytes, lastType));

        return components;
    }

    /**
     * If the SSTable version has metadata checksums, reads the next int from {@code in} and validates it
     * against {@code crc}
     */
    private static void maybeValidateChecksum(CRC32 crc, DataInputStream in, Descriptor descriptor) throws IOException
    {
        if (descriptor.version.hasMetadataChecksum())
        {
            validateChecksum(crc, in.readInt(), descriptor);
        }
    }

    /**
     * Compares the (truncated to int) CRC32 value with the expected checksum
     *
     * @throws CorruptSSTableException if they differ
     */
    private static void validateChecksum(CRC32 crc, int expectedChecksum, Descriptor descriptor)
    {
        int actualChecksum = (int) crc.getValue();

        if (actualChecksum != expectedChecksum)
        {
            throw corruptStatsException(descriptor, "Checksums do not match for ");
        }
    }

    /**
     * Builds a {@link CorruptSSTableException} for the Statistics.db component of {@code descriptor}
     *
     * @param descriptor SSTable file descriptor
     * @param reason     message prefix; the stats file name is appended to it
     * @return exception to throw
     */
    private static CorruptSSTableException corruptStatsException(Descriptor descriptor, String reason)
    {
        String filename = descriptor.fileFor(SSTableFormat.Components.STATS).name();
        return new CorruptSSTableException(new IOException(reason + filename), filename);
    }

    /**
     * Reads a ValidationMetadata component: a UTF partitioner name followed by a double
     */
    private static MetadataComponent deserializeValidationMetaData(@NotNull DataInputBuffer in) throws IOException
    {
        return new ValidationMetadata(in.readUTF(), in.readDouble());
    }

    /**
     * Deserializes one metadata component. HEADER and VALIDATION use local deserializers; all other types
     * delegate to the type's own serializer.
     */
    private static MetadataComponent deserializeMetadataComponent(@NotNull Version version,
                                                                  @NotNull byte[] buffer,
                                                                  @NotNull MetadataType type) throws IOException
    {
        DataInputBuffer in = new DataInputBuffer(buffer);
        if (type == MetadataType.HEADER)
        {
            return deserializeSerializationHeader(in);
        }
        else if (type == MetadataType.VALIDATION)
        {
            return deserializeValidationMetaData(in);
        }
        return type.serializer.deserialize(version, in);
    }

    /**
     * Deserializes a SerializationHeader.Component: encoding stats, key type, clustering types,
     * then static and regular column types, parsing type names with {@link TypeParser}
     */
    private static MetadataComponent deserializeSerializationHeader(@NotNull DataInputBuffer in) throws IOException
    {
        // We need to deserialize data type class names using shaded package names
        EncodingStats stats = EncodingStats.serializer.deserialize(in);
        AbstractType<?> keyType = readType(in);
        int size = (int) in.readUnsignedVInt();
        List<AbstractType<?>> clusteringTypes = new ArrayList<>(size);

        for (int index = 0; index < size; ++index)
        {
            clusteringTypes.add(readType(in));
        }

        Map<ByteBuffer, AbstractType<?>> staticColumns = new LinkedHashMap<>();
        Map<ByteBuffer, AbstractType<?>> regularColumns = new LinkedHashMap<>();
        readColumnsWithType(in, staticColumns);
        readColumnsWithType(in, regularColumns);

        try
        {
            // TODO: We should expose this code in Cassandra to make it easier to do this with unit tests in Cassandra
            return (SerializationHeader.Component) SERIALIZATION_HEADER.newInstance(keyType,
                                                                                    clusteringTypes,
                                                                                    staticColumns,
                                                                                    regularColumns,
                                                                                    stats);
        }
        catch (InstantiationException | IllegalAccessException | InvocationTargetException exception)
        {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Reads a vint count followed by that many (vint-length-prefixed name, type) pairs into {@code typeMap}
     */
    private static void readColumnsWithType(@NotNull DataInputPlus in,
                                            @NotNull Map<ByteBuffer, AbstractType<?>> typeMap) throws IOException
    {
        int length = (int) in.readUnsignedVInt();
        for (int index = 0; index < length; index++)
        {
            ByteBuffer name = ByteBufferUtil.readWithVIntLength(in);
            typeMap.put(name, readType(in));
        }
    }

    /**
     * Reads a vint-length-prefixed UTF-8 type name and parses it into an {@link AbstractType}
     */
    private static AbstractType<?> readType(@NotNull DataInputPlus in) throws IOException
    {
        return TypeParser.parse(UTF8Type.instance.compose(ByteBufferUtil.readWithVIntLength(in)));
    }

    /**
     * Scans the whole primary Index.db file and returns its first and last partition keys
     *
     * @param primaryIndex input stream for Index.db file
     * @return pair of (first, last) raw keys; both elements are {@code null} if the index has no entries
     * @throws IOException if the stream cannot be read
     */
    public static Pair<ByteBuffer, ByteBuffer> primaryIndexReadFirstAndLastKey(@NotNull InputStream primaryIndex) throws IOException
    {
        ByteBuffer[] firstAndLast = new ByteBuffer[]{null, null};
        readPrimaryIndex(primaryIndex, (buffer) -> {
            if (firstAndLast[0] == null)
            {
                firstAndLast[0] = buffer;
            }
            firstAndLast[1] = buffer;
            return false; // never exit early
        });
        return Pair.of(firstAndLast[0], firstAndLast[1]);
    }

    /**
     * Reads primary Index.db file returning true and exiting early if it contains any of the PartitionKeyFilter
     *
     * @param primaryIndex input stream for Index.db file
     * @param filters      list of filters to search for
     * @return true if Index.db file contains any of the keys
     * @throws IOException if the stream cannot be read
     */
    public static boolean primaryIndexContainsAnyKey(@NotNull InputStream primaryIndex,
                                                     @NotNull List<PartitionKeyFilter> filters) throws IOException
    {
        final boolean[] result = new boolean[]{false};
        readPrimaryIndex(primaryIndex, (buffer) -> {
            boolean anyMatch = filters.stream().anyMatch(filter -> filter.matches(buffer));
            if (anyMatch)
            {
                result[0] = true;
                return true; // exit early, we found at least one key
            }
            return false;
        });
        return result[0];
    }

    /**
     * Read primary Index.db file. Each entry is an unsigned-short length-prefixed key, followed by a vint position
     * and a promoted index (vint size, then that many bytes) which are skipped. Reading stops at end of stream.
     * The given stream is closed when this method returns.
     *
     * @param primaryIndex input stream for Index.db file
     * @param tracker      tracker that consumes each key buffer and returns true if can exit early, otherwise continues to read primary index
     * @throws IOException if the stream cannot be read
     */
    public static void readPrimaryIndex(@NotNull InputStream primaryIndex,
                                        @NotNull Function<ByteBuffer, Boolean> tracker) throws IOException
    {
        try (DataInputStream dis = new DataInputStream(primaryIndex))
        {
            try
            {
                while (true)
                {
                    int length = dis.readUnsignedShort();
                    byte[] array = new byte[length];
                    dis.readFully(array);
                    ByteBuffer buffer = ByteBuffer.wrap(array);
                    if (tracker.apply(buffer))
                    {
                        // exit early if tracker returns true
                        return;
                    }

                    // Read position and skip promoted index
                    skipRowIndexEntry(dis);
                }
            }
            catch (EOFException ignored)
            {
                // End of Index.db reached: no more entries to read
            }
        }
    }

    /**
     * Skips the remainder of a row index entry after its key: the position vint and the promoted index
     */
    static void skipRowIndexEntry(DataInputStream dis) throws IOException
    {
        readPosition(dis);
        skipPromotedIndex(dis);
    }

    /**
     * @return number of bytes needed to encode {@code value} as an unsigned vint
     */
    static int vIntSize(long value)
    {
        return VIntCoding.computeUnsignedVIntSize(value);
    }

    /**
     * Writes {@code value} into {@code buffer} as an unsigned vint
     */
    static void writePosition(long value, ByteBuffer buffer)
    {
        VIntCoding.writeUnsignedVInt(value, buffer);
    }

    /**
     * Reads an unsigned vint position from {@code dis}
     */
    static long readPosition(DataInputStream dis) throws IOException
    {
        return VIntCoding.readUnsignedVInt(dis);
    }

    /**
     * Skips a promoted index: an unsigned vint size followed by that many bytes
     *
     * @return the total bytes skipped, including the size vint itself
     */
    public static int skipPromotedIndex(DataInputStream dis) throws IOException
    {
        final long val = VIntCoding.readUnsignedVInt(dis);
        final int size = (int) val;
        if (size > 0)
        {
            ByteBufferUtils.skipBytesFully(dis, size);
        }
        return Math.max(size, 0) + VIntCoding.computeUnsignedVIntSize(val);
    }

    /**
     * Returns the subset of filters whose keys may be present according to the SSTable's bloom filter.
     * If the filter file is not found, all filters are returned unchanged.
     *
     * @param ssTable             SSTable to check
     * @param partitioner         partitioner used to decorate filter keys
     * @param descriptor          SSTable descriptor
     * @param partitionKeyFilters filters to test
     * @return filters that may match a key in the SSTable
     * @throws IOException if the bloom filter cannot be read
     */
    static List<PartitionKeyFilter> filterKeyInBloomFilter(
    @NotNull SSTable ssTable,
    @NotNull IPartitioner partitioner,
    Descriptor descriptor,
    @NotNull List<PartitionKeyFilter> partitionKeyFilters) throws IOException
    {
        try
        {
            BloomFilter bloomFilter = SSTableCache.INSTANCE.bloomFilter(ssTable, descriptor);
            return partitionKeyFilters.stream()
                                      .filter(filter -> bloomFilter.isPresent(partitioner.decorateKey(filter.key())))
                                      .collect(Collectors.toList());
        }
        catch (Exception exception)
        {
            if (exception instanceof FileNotFoundException)
            {
                // No Filter.db available: cannot rule out any key
                return partitionKeyFilters;
            }
            throw exception;
        }
    }

    /**
     * Reads the SSTable's bloom filter using the serialization format implied by the descriptor's version
     *
     * @see #readFilter(SSTable, boolean)
     */
    public static BloomFilter readFilter(@NotNull SSTable ssTable, Descriptor descriptor) throws IOException
    {
        return readFilter(ssTable, descriptor.version.hasOldBfFormat());
    }

    /**
     * Reads the SSTable's bloom filter
     *
     * @param ssTable        SSTable whose filter stream is read
     * @param hasOldBfFormat whether the filter uses the old bloom filter serialization format
     * @return deserialized bloom filter
     * @throws FileNotFoundException if the SSTable has no filter stream
     * @throws IOException           if the filter cannot be read
     */
    public static BloomFilter readFilter(@NotNull SSTable ssTable, boolean hasOldBfFormat) throws IOException
    {
        try (InputStream filterStream = ssTable.openFilterStream())
        {
            if (filterStream != null)
            {
                int bufferSize = inputStreamBufferSize(filterStream);
                try (DataInputStream dis = new DataInputStream(filterStream);
                     DataInputPlus.DataInputStreamPlus in = new RebufferingChannelInputStream(dis, bufferSize))
                {
                    return BloomFilterSerializer.forVersion(hasOldBfFormat).deserialize(in);
                }
            }
        }
        throw new FileNotFoundException("Filter component not found for SSTable " + ssTable.getDataFileName());
    }

    /**
     * If known, return internal buffer size of given input stream, {@code -1} otherwise.
     */
    public static int inputStreamBufferSize(InputStream inputStream)
    {
        if (inputStream instanceof BufferingInputStream<?>)
        {
            BufferingInputStream<?> bis = (BufferingInputStream<?>) inputStream;
            return Ints.checkedCast(bis.chunkBufferSize());
        }
        return -1;
    }
}
