CHARacter-awaRE Diffusion: Multilingual Character-Aware Encoders for Font-Aware Diffusers That Can Actually Spell
Tired of text-to-image models that can't spell or handle fonts and typography correctly? The secret seems to be the use of multilingual, tokenization-free, character-aware transformer encoders such as ByT5 and CANINE-c.
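To make "tokenization-free" concrete, here is a minimal sketch of byte-level encoding in the style of ByT5. It assumes the convention that ids 0, 1 and 2 are reserved for padding, end-of-sequence and unknown, so each UTF-8 byte is shifted by 3. The helper names are illustrative and do not belong to any library.

```python
from typing import List

# Sketch of ByT5-style byte-level encoding (illustrative helper names).
# Assumption: ids 0, 1, 2 are reserved (pad, eos, unk), so byte b maps to id b + 3.
SPECIAL_OFFSET = 3
EOS_ID = 1


def encode_bytes(text: str) -> List[int]:
    """Map every UTF-8 byte of `text` to an id, then append the EOS id."""
    return [byte + SPECIAL_OFFSET for byte in text.encode("utf-8")] + [EOS_ID]


def decode_bytes(ids: List[int]) -> str:
    """Invert `encode_bytes`, ignoring any id outside the 256 byte ids."""
    raw = bytes(
        i - SPECIAL_OFFSET
        for i in ids
        if SPECIAL_OFFSET <= i < SPECIAL_OFFSET + 256
    )
    return raw.decode("utf-8", errors="ignore")


if __name__ == "__main__":
    for word in ["hello", "caf\u00e9", "\u1230\u120b\u121d"]:
        ids = encode_bytes(word)
        # Print the word, its encoded length and whether it round-trips.
        print(word, len(ids), decode_bytes(ids) == word)
```

Because the encoder sees one id per byte, every letter of a word stays visible to it, in any script. The trade-off is longer sequences: characters outside ASCII take two or more bytes each.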
As part of the Hugging Face JAX Diffuser Sprint, we will replace CLIP's tokenizer and encoder with ByT5's in Hugging Face's JAX/FLAX text-to-image pre-training code and run it on the sponsored TPU resources provided by Google for the event.
More specifically, here are the main tasks we will try to accomplish during the sprint:
- Pre-training dataset preparation: we are NOT going to train on lambdalabs/pokemon-blip-captions. So what is it going to be, and what are the options? Does anything in here or here take your fancy? Or maybe DiffusionDB? Or a clever mix of many datasets? We will probably need to combine several datasets to cover these requirements:
  - We need samples in which the scene contains text that is explicitly specified in the caption, with priority given to full-scene photos. If we can't find enough, we will integrate more specialized OCR datasets;
  - Approximately the same language distribution as ByT5, plus Indonesian (not in ByT5) to see how character-awareness works when the text in the prompt is specified in such a language. We need to build testing facilities around the languages spoken by team members and friends: Indonesian, Japanese, French, Amharic, Arabic, Norwegian, Swedish, Hindi, Urdu and English.

  We should use the Hugging Face Datasets library as much as possible since it supports JAX out of the box. For simplicity's sake, we will limit ourselves to concatenated Hugging Face datasets such as LAION2B EN, MULTI and NOLANG. We shall, however, pre-load, pre-process and cache the dataset on disk before training on it.
- Improvements to the original code:
  - Make sure we can run the original code as-is on the TPU VM.
  - Audit and optimize the code for the Google Cloud TPU v4-8 VM: jnp (instead of np), jit, grad, vmap, pmap, pjit everywhere! And we should make sure we do not miss any optimization made in the sprint code either.
  - Instrumentation for TPU remote monitoring with OpenTelemetry, TensorBoard, Perfetto, Weights & Biases and JAX's own profiler.
  - Implement checkpoint milestone snapshot uploading to cloud storage: we need to be able to download the model for local inference benchmarking to make sure we are on the right track. There seems to be rudimentary checkpoint support in the original code.
  - No time for politics. NSFW filtering will be turned off, so we get FlaxStableDiffusionSafetyChecker out of the way.
- Replace CLIP with ByT5 in the original code:
  - Replacing CLIPTokenizer with ByT5Tokenizer (status: merged, needs testing). Since this will run on the CPUs, there is no need for JAX/FLAX unless there is hope for huge performance improvements. This should be trivial.
  - Replacing FlaxCLIPTextModel with FlaxT5EncoderModel (status: merged, needs testing). This might be almost as easy as replacing the tokenizer.
  - Rewrite CLIPImageProcessor for ByT5 (status: done, needs testing). This is still under investigation. It's unclear how hard it will be.
  - Adapt FlaxAutoencoderKL and FlaxUNet2DConditionModel for ByT5 if necessary (status: done, needs testing).
  - Break down the main pretraining loop into many functions in different source files for readability and easier maintenance.
Secondly, we will integrate into the above a Hugging Face JAX/FLAX ControlNet implementation for better typographic control over the generated images. On top of the orthographically enhanced SD above, and as per Peter von Platen's suggestion, we also introduce the idea of a typographic ControlNet trained on a synthetic dataset of images paired with multilingual specifications of the textual content, font taxonomy, weight, kerning, leading, slant and any other typographic attribute supported by the CSS3 Text, Fonts and Writing Modes modules, as implemented by the latest version of Chromium.
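As a rough illustration of what one synthetic sample specification might look like, the sketch below turns a typographic spec into an HTML page that a headless browser could then rasterize. The spec keys and the `render_html` helper are hypothetical and not a fixed schema. The CSS properties themselves (`font-family`, `font-weight`, `font-kerning`, `line-height`, `font-style`, `writing-mode`) are standard. The rasterization step is not shown.

```python
from html import escape


def render_html(spec: dict) -> str:
    """Build a small HTML page from a (hypothetical) typographic spec."""
    css = {
        "font-family": spec["family"],
        "font-weight": str(spec["weight"]),    # weight, e.g. 400 or 700
        "font-kerning": spec["kerning"],       # "auto", "normal" or "none"
        "line-height": str(spec["leading"]),   # leading, as a unitless multiplier
        "font-style": spec["slant"],           # "normal", "italic" or "oblique"
        "writing-mode": spec["writing_mode"],  # e.g. "horizontal-tb", "vertical-rl"
    }
    style = "; ".join(f"{prop}: {value}" for prop, value in css.items())
    # Escape everything that ends up inside attributes or text nodes.
    return (
        f'<html lang="{escape(spec["lang"])}"><body>'
        f'<p style="{escape(style)}">{escape(spec["text"])}</p>'
        "</body></html>"
    )


if __name__ == "__main__":
    sample_spec = {
        "text": "Caf\u00e9 & Bar",
        "lang": "fr",
        "family": "Noto Sans",
        "weight": 700,
        "kerning": "normal",
        "leading": 1.4,
        "slant": "italic",
        "writing_mode": "horizontal-tb",
    }
    print(render_html(sample_spec))
```

Keeping the spec as plain data means the same dictionary can serve both as the conditioning input and as the recipe for rendering the paired image.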