/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.search.SearchHit;
import java.util.ArrayList;
import java.util.List;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.training.TrainingDataConsumer;

/**
 * {@link TrainingDataConsumer} for byte-valued training vectors.
 *
 * <p>Reads numeric vector fields out of search hits, narrows each component to a {@code byte},
 * and passes the resulting {@code byte[]} vectors to {@link JNICommons#storeByteVectorData},
 * updating the training-data allocation's memory address with the value that call returns.
 */
public class ByteTrainingDataConsumer extends TrainingDataConsumer {

    /**
     * Creates a consumer that writes into the given training-data allocation.
     *
     * @param trainingDataAllocation allocation whose memory address is read and updated by
     *     {@link #accept(List)}
     */
    public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
        super(trainingDataAllocation);
    }

    /**
     * Hands a batch of byte vectors to native code.
     *
     * <p>The allocation's current memory address is passed to
     * {@link JNICommons#storeByteVectorData} together with the vectors, and the address returned
     * by that call is stored back on the allocation.
     *
     * @param byteVectors the vectors to store; every element must be a {@code byte[]}, otherwise
     *     {@link ArrayStoreException} is thrown while copying into the array
     */
    @Override
    public void accept(List<?> byteVectors) {
        long memoryAddress = this.trainingDataAllocation.getMemoryAddress();
        // toArray(T[]) with a byte[][] seed infers T = byte[] and returns byte[][] directly.
        byte[][] vectorArray = byteVectors.toArray(new byte[0][]);
        memoryAddress = JNICommons.storeByteVectorData(memoryAddress, vectorArray, vectorArray.length);
        this.trainingDataAllocation.setMemoryAddress(memoryAddress);
    }

    /**
     * Extracts up to {@code vectorsToAdd} byte vectors from the hits of {@code searchResponse} and
     * stores them via {@link #accept(List)}.
     *
     * <p>The field is located by splitting {@code fieldName} on {@code '.'} and resolving the path
     * with {@code extractFieldValue}. Hits whose value is not a {@link List} are skipped. Each list
     * element is cast to {@link Number} and narrowed with {@link Number#byteValue()}; no range
     * check is performed, so out-of-range values are truncated by the narrowing conversion.
     * The running total is increased only by the number of vectors actually converted.
     *
     * @param searchResponse the search page to read vectors from
     * @param vectorsToAdd maximum number of hits to consider; clamped to the number of hits
     *     actually present in the response
     * @param fieldName dot-separated path of the vector field within each hit
     */
    @Override
    public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        String[] fieldPath = fieldName.split("\\.");

        // Never index past the hits actually returned; previously this could throw
        // ArrayIndexOutOfBoundsException when vectorsToAdd exceeded the page size.
        int limit = Math.min(vectorsToAdd, hits.length);
        List<byte[]> vectors = new ArrayList<>(Math.max(limit, 0));

        for (int vector = 0; vector < limit; ++vector) {
            Object fieldValue = this.extractFieldValue(hits[vector], fieldPath);
            if (!(fieldValue instanceof List)) {
                continue;
            }
            List<?> fieldList = (List<?>) fieldValue;
            byte[] byteArray = new byte[fieldList.size()];
            for (int i = 0; i < byteArray.length; ++i) {
                byteArray[i] = ((Number) fieldList.get(i)).byteValue();
            }
            vectors.add(byteArray);
        }

        this.setTotalVectorsCountAdded(this.getTotalVectorsCountAdded() + vectors.size());
        this.accept(vectors);
    }
}

