Keras를 사용하여 Gemma를 사용한 분산 조정

ai.google.dev에서 보기 Google Colab에서 실행 Kaggle에서 실행하기 Vertex AI에서 열기 GitHub에서 소스 보기

개요

Gemma는 Google Gemini 모델을 만드는 데 사용된 연구와 기술을 바탕으로 구축된 최첨단 경량 개방형 모델 제품군입니다. Gemma는 특정 필요에 맞게 더욱 미세하게 조정할 수 있습니다. 그러나 Gemma와 같은 대규모 언어 모델은 크기가 매우 클 수 있으며, 일부는 미세 조정을 위한 단일 액셀러레이터에 맞지 않을 수 있습니다. 이 경우 일반적으로 다음과 같은 두 가지 방법으로 미세 조정할 수 있습니다.

  1. Parameter-Efficient Fine-Tuning (매개변수 미세 조정)은 일부 충실도를 희생하여 유효 모델 크기를 축소하는 것을 목표로 합니다. LoRA가 이 카테고리에 속합니다. LoRA를 사용하여 Keras에서 Gemma 모델 세부 조정 튜토리얼에서는 단일 GPU에서 KerasNLP를 사용하여 LoRA로 Gemma 2B 모델 gemma_2b_en를 미세 조정하는 방법을 보여줍니다.
  2. 모델 동시 로드를 통한 전체 매개변수 미세 조정 모델 동시 로드는 단일 모델의 가중치를 여러 기기에 분산하고 수평 확장을 지원합니다. 분산 학습에 대한 자세한 내용은 이 Keras 가이드를 참조하세요.

이 튜토리얼에서는 JAX 백엔드와 함께 Keras를 사용하여 LoRA로 Gemma 7B 모델을 미세 조정하고 Google의 Tensor Processing Unit (TPU)에서 모델-패럴리즘 분산 학습을 사용하는 방법을 안내합니다. 느리지만 더 정확한 전체 매개변수 조정을 위해 이 튜토리얼에서 LoRA를 사용 중지할 수 있습니다.

가속기 사용

기술적으로 이 튜토리얼에서는 TPU 또는 GPU를 사용할 수 있습니다.

TPU 환경 참고사항

Google에는 TPU를 제공하는 3가지 제품이 있습니다.

  • Colab은 TPU v2를 무료로 제공하므로 이 튜토리얼을 진행하기에 충분합니다.
  • Kaggle은 TPU v3를 무료로 제공하며 이 튜토리얼에서도 작동합니다.
  • Cloud TPU는 TPU v3 및 최신 세대를 제공합니다. 설정 방법은 다음과 같습니다.
    1. TPU VM을 만듭니다.
    2. 원하는 Jupyter 서버 포트에 SSH 포트 전달을 설정합니다.
    3. Jupyter를 설치하고 TPU VM에서 시작한 다음 '로컬 런타임에 연결'을 통해 Colab에 연결하세요.

다중 GPU 설정 참고사항

이 튜토리얼에서는 TPU 사용 사례를 중점적으로 다루지만 다중 GPU 머신을 사용하는 경우 사용자의 필요에 맞게 쉽게 조정할 수 있습니다.

Colab을 통해 작업하고 싶다면 '커스텀 GCE VM에 연결'을 통해 Colab용 다중 GPU VM을 직접 프로비저닝할 수도 있습니다. 찾을 수 있습니다

여기에서는 Kaggle의 무료 TPU를 사용하는 데 중점을 두겠습니다.

시작하기 전에

Kaggle 사용자 인증 정보

Gemma 모델은 Kaggle에서 호스팅됩니다. Gemma를 사용하려면 Kaggle에서 액세스 권한을 요청하세요.

  • kaggle.com에서 로그인 또는 등록
  • Gemma 모델 카드를 열고 '액세스 요청'을 선택합니다.
  • 동의 양식을 작성하고 이용약관에 동의합니다.

그런 다음 Kaggle API를 사용하기 위해 API 토큰을 만듭니다.

  • 캐글 설정을 엽니다.
  • '새 토큰 만들기'를 선택합니다.
  • kaggle.json 파일이 다운로드됩니다. 여기에는 Kaggle 사용자 인증 정보가 포함되어 있습니다.

다음 셀을 실행하고 메시지가 표시되면 Kaggle 사용자 인증 정보를 입력합니다.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

kagglehub.login()이 작동하지 않는 경우 사용 중인 환경에서 KAGGLE_USERNAME 및 KAGGLE_KEY를 설정하는 방법도 있습니다.

설치

Gemma 모델을 사용하여 Keras와 KerasNLP를 설치합니다.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX 백엔드 설정

JAX를 가져오고 TPU에서 상태 검사를 실행합니다. Kaggle은 8개의 TPU 코어와 각각 16GB의 메모리를 갖춘 TPUv3-8 기기를 제공합니다.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

모델 로드

import keras
import keras_nlp

NVIDIA GPU의 혼합 정밀도 학습에 관한 참고사항

NVIDIA GPU에서 학습할 때 혼합 정밀도 (keras.mixed_precision.set_global_policy('mixed_bfloat16'))를 사용하면 학습 품질에 미치는 영향을 최소화하면서 학습 속도를 높일 수 있습니다. 대부분의 경우 메모리와 시간이 모두 절약되므로 혼합 정밀도를 사용 설정하는 것이 좋습니다. 그러나 작은 배치 크기에서는 메모리 사용량을 1.5배 늘릴 수 있습니다 (가중치는 절반 정밀도와 전체 정밀도로 두 번 로드됨).

추론의 경우 혼합 정밀도는 적용되지 않지만 절반 정밀도 (keras.config.set_floatx("bfloat16"))는 작동하고 메모리를 절약합니다.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

TPU 전반에 분산된 가중치와 텐서로 모델을 로드하려면 먼저 새 DeviceMesh를 만듭니다. DeviceMesh는 분산 계산을 위해 구성된 하드웨어 기기 모음을 나타내며 통합 배포 API의 일부로 Keras 3에 도입되었습니다.

Distribution API는 데이터 및 모델 동시 로드를 지원하므로 여러 가속기 및 호스트에서 딥 러닝 모델을 효율적으로 확장할 수 있습니다. 기본 프레임워크 (예: JAX)를 활용하여 단일 프로그램, 다중 데이터 (SPMD) 확장이라는 절차를 통해 샤딩 지시문에 따라 프로그램과 텐서를 배포합니다. 자세한 내용은 새로운 Keras 3 배포 API 가이드를 참고하세요.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

배포 API의 LayoutMap는 텐서 경로를 일치시키기 위해 정규식처럼 처리되는 문자열 키(예: 아래의 token_embedding/embeddings)를 사용하여 가중치와 텐서를 샤딩하거나 복제하는 방법을 지정합니다. 일치하는 텐서는 모델 차원 (TPU 8개)으로 샤딩됩니다. 나머지는 완전히 복제됩니다

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel를 사용하면 DeviceMesh의 모든 기기에서 모델 가중치 또는 활성화 텐서를 샤딩할 수 있습니다. 이 경우 Gemma 7B 모델 가중치 중 일부는 위에 정의된 layout_map에 따라 8개의 TPU 칩에서 샤딩됩니다. 이제 모델을 분산된 방식으로 로드합니다.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

이제 모델의 파티션을 올바르게 나누었는지 확인합니다. decoder_block_1를 예로 들어보겠습니다.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

미세 조정 전 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

이 모델이 90년대 최고의 코미디 영화 목록을 생성합니다. 이제 Gemma 모델을 미세 조정하여 출력 스타일을 변경합니다.

IMDB로 미세 조정

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

낮은 순위 조정 (LoRA)을 사용하여 미세 조정을 수행합니다. LoRA는 모델의 전체 가중치를 고정하고 모델에 더 적은 수의 새로운 학습 가능 가중치를 삽입하여 다운스트림 작업의 학습 가능한 매개변수 수를 크게 줄이는 미세 조정 기법입니다. 기본적으로 LoRA는 큰 전체 가중치 행렬을 2개의 작은 하위 행렬 AxB로 다시 매개변수화하여 학습합니다. 이 기법을 통해 학습 속도가 훨씬 빠르고 메모리 효율성도 높아집니다.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

LoRA를 사용 설정하면 학습 가능한 매개변수 수가 70억 개에서 1, 100만 개로 크게 줄어듭니다.

미세 조정 후 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

미세 조정 후 모델은 영화 리뷰 스타일을 학습했으며 이제 90년대 코미디 영화 맥락에서 해당 스타일로 출력을 생성합니다.

다음 단계

이 튜토리얼에서는 KerasNLP JAX 백엔드를 사용하여 강력한 TPU에 분산된 방식으로 IMDb 데이터 세트의 Gemma 모델을 미세 조정하는 방법을 배웠습니다. 다음은 추가로 알아두어야 할 사항의 몇 가지 제안사항입니다.