Utilizzare un modello TensorFlow Lite personalizzato su Android

Se la tua app utilizza modelli TensorFlow Lite, puoi utilizzare Firebase ML per eseguire il deployment dei modelli. Di eseguendo il deployment di modelli con Firebase, puoi ridurre le dimensioni l'app e aggiorna i modelli ML dell'app senza rilasciare una nuova versione la tua app. Inoltre, con Remote Config e A/B Testing, puoi eseguire dinamicamente per distribuire modelli diversi a insiemi di utenti differenti.

Modelli TensorFlow Lite

I modelli TensorFlow Lite sono modelli di ML ottimizzati per l'esecuzione sui dispositivi mobili dispositivi mobili. Per ottenere un modello TensorFlow Lite:

Prima di iniziare

  1. Se non l'hai già fatto, aggiungi Firebase al tuo progetto Android.
  2. Nel file Gradle del modulo (a livello di app) (di solito <project>/<app-module>/build.gradle.kts o <project>/<app-module>/build.gradle), aggiungi la dipendenza per la libreria del downloader di modelli Firebase ML per Android. Ti consigliamo di utilizzare Firebase Android BoM per controllare il controllo delle versioni delle librerie.

    Inoltre, durante la configurazione del downloader modello Firebase ML, devi aggiungere il metodo TensorFlow Lite SDK alla tua app.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.2.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Se utilizzi Firebase Android BoM, la tua app utilizzerà sempre versioni compatibili delle librerie Firebase Android.

    (Alternativa)  Aggiungi le dipendenze della libreria Firebase senza utilizzare il file BoM

    Se scegli di non utilizzare Firebase BoM, devi specificare ogni versione della libreria Firebase nella relativa riga di dipendenza.

    Tieni presente che se utilizzi più librerie Firebase nella tua app, ti consigliamo consiglia di utilizzare BoM per gestire le versioni della libreria, in modo da garantire che tutte le versioni siano compatibili.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Cerchi un modulo della libreria specifico per Kotlin? A partire da Ottobre 2023 (Firebase BoM 32.5.0), gli sviluppatori Kotlin e Java possono dipendono dal modulo principale della libreria (per i dettagli, consulta Domande frequenti su questa iniziativa).
  3. Nel file manifest dell'app, dichiara che è necessaria l'autorizzazione INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Esegui il deployment del modello

Esegui il deployment dei tuoi modelli TensorFlow personalizzati utilizzando la console Firebase oppure gli SDK Firebase Admin e Node.js. Consulta Distribuire e gestire modelli personalizzati.

Dopo aver aggiunto un modello personalizzato al tuo progetto Firebase, puoi fare riferimento alla tuo modello nelle tue app utilizzando il nome specificato. Puoi eseguire il deployment in qualsiasi momento un nuovo modello TensorFlow Lite e scaricarlo sul sito web degli utenti dispositivi per chiamata al numero getModel() (vedi sotto).

2. Scarica il modello sul dispositivo e inizializza un interprete TensorFlow Lite

Per utilizzare il modello TensorFlow Lite nella tua app, usa prima l'SDK Firebase ML per scaricare l'ultima versione del modello sul dispositivo. Dopodiché, crea un interprete TensorFlow Lite con il modello.

Per avviare il download del modello, chiama il metodo getModel() del downloader del modello, specificando il nome assegnato al modello al momento del caricamento, se vuoi scaricare sempre il modello più recente e le condizioni in cui vuoi consentire il download.

Puoi scegliere fra tre tipi di comportamento di download:

Tipo di download Descrizione
MODELLO_LOCALE Recupera il modello locale dal dispositivo. Se non è disponibile alcun modello locale, si comporta come LATEST_MODEL. Usa questa tipo di download se non ti interessa controllare la disponibilità di aggiornamenti del modello. Ad esempio, utilizzi Remote Config per recuperare i nomi dei modelli e carichi sempre i modelli con nuovi nomi (opzione consigliata).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Recupera il modello locale dal dispositivo e iniziare ad aggiornare il modello in background. Se non è disponibile alcun modello locale, il valore si comporta come LATEST_MODEL.
ULTIMO_MODELLO Acquista il modello più recente. Se il modello locale è all'ultima versione, restituisce il token un modello di machine learning. Altrimenti, scarica l'ultima versione un modello di machine learning. Questo comportamento bloccherà fino a quando viene scaricata l'ultima versione (non consigliato). Utilizza questo comportamento solo nei casi in cui hai bisogno esplicitamente della versione più recente.

È necessario disattivare le funzionalità correlate al modello, ad esempio nasconde parte dell'interfaccia utente, fino a quando non confermi il download del modello.

Kotlin KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Molte app avviano l'attività di download nel codice di inizializzazione, ma puoi farlo quindi in qualsiasi momento prima di dover usare il modello.

3. Esegui l'inferenza sui dati di input

Ottieni le forme di input e output del tuo modello

L'interprete di modelli TensorFlow Lite prende come input e lo produce come output uno o più array multidimensionali. Questi array contengono byte, int, long o float e i relativi valori. Prima di poter passare i dati a un modello o utilizzare il suo risultato, devi conoscere il numero e le dimensioni ("forma") degli array utilizzati dal modello.

Se il modello è stato creato da te o se il formato di input e output del modello è potrebbero essere già state documentate, potresti già avere queste informazioni. Se non conosci la forma e il tipo di dati di input e output del modello, puoi utilizzare Interprete TensorFlow Lite per ispezionare il modello. Ad esempio:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Output di esempio:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Esegui l'interprete

Dopo aver determinato il formato di input e output del modello, ottieni il di input ed eseguire le trasformazioni necessarie per ottenere un input della forma giusta per il tuo modello.

Ad esempio, se hai un modello di classificazione delle immagini con una forma di input di valori di tipo [1 224 224 3] a virgola mobile, puoi generare un input ByteBuffer da un oggetto Bitmap come mostrato nell'esempio seguente:

Kotlin KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y  ) {
    for (int x = 0; x < 224; x  ) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Quindi, alloca ByteBuffer sufficientemente grande da contenere l'output del modello passare il buffer di input e di output al prompt Metodo run(). Ad esempio, per una forma di output con rappresentazione in virgola mobile [1 1000] valori:

Kotlin KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

La modalità di utilizzo dell'output dipende dal modello utilizzato.

Ad esempio, se stai eseguendo la classificazione, come passaggio successivo potresti mappare gli indici del risultato alle etichette che rappresentano:

Kotlin KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i  ) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Appendice: Sicurezza del modello

Indipendentemente da come rendi disponibili i tuoi modelli TensorFlow Lite Firebase ML, Firebase ML li archivia nel formato protobuf serializzato standard in archiviazione locale.

In teoria, questo significa che chiunque può copiare il modello. Tuttavia, in pratica, la maggior parte dei modelli è così specifica per l'applicazione e offuscata dalle ottimizzazioni che il rischio è simile a quello dei concorrenti che smontano e riutilizzano il codice. Tuttavia, è necessario essere consapevoli di questo rischio prima di utilizzare un modello personalizzato nella tua app.

Sul livello API Android 21 (Lollipop) e versioni successive, il modello viene scaricato in un che è esclusi dal backup automatico.

Sul livello API Android 20 e versioni precedenti, il modello viene scaricato in una directory denominato com.google.firebase.ml.custom.models in privato dell'app memoria interna. Se hai attivato il backup dei file utilizzando BackupAgent, puoi scegliere di escludere questa directory.