Skip to content

Commit

Permalink
Add a driver.f90
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Mar 21, 2023
1 parent e7d4d97 commit 6bf685f
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 48,7 @@ enable_testing()
set(SRC
gpt2.f90
tokenizer.f90
driver.f90
)
if (FASTGPT_BLAS STREQUAL "Accelerate")
list(APPEND SRC
Expand Down
177 changes: 177 additions & 0 deletions driver.f90
Original file line number Diff line number Diff line change
@@ -0,0 1,177 @@
module driver
use gpt2_mod, only: generate
use tokenizer, only: encode, decode
use omp, only: omp_get_wtime
implicit none

integer, parameter :: sp = kind(0.0)
integer, parameter :: dp = kind(0.d0)

contains

subroutine gpt2_driver()
integer :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_tokens_to_generate, n_decoder_idx, n_decoder_txt, &
n_vocab_idx, n_vocab_txt, n_byte_encoder
integer, allocatable :: input(:), decoder_idx(:), vocab_idx(:), byte_decoder(:)
integer :: byte_encoder(0:255)
real(sp), allocatable :: wte(:,:), wpe(:,:), &
mlp_fc_w(:,:,:), mlp_fc_b(:,:), &
mlp_proj_w(:,:,:), mlp_proj_b(:,:), &
attn_w(:,:,:), attn_b(:,:), &
attn_proj_w(:,:,:), attn_proj_b(:,:), &
ln1_b(:,:), ln1_g(:,:), &
ln2_b(:,:), ln2_g(:,:), &
lnf_b(:), lnf_g(:)
character, allocatable :: decoder_txt(:), vocab_txt(:)
integer, allocatable :: output(:)
character(:), allocatable :: output_txt, input_txt
character(1024) :: input_txt2
real(dp) :: t1, t2, t1o, t2o
integer :: u, i, ios
logical :: use_cache
integer, parameter :: input_ref(*) = [36235, 39141, 18765, 1143, 326, 9061, &
561, 530, 1110, 1716, 845, 3665, 11, 475, 772, 339, 714, 407, 5967]
integer, parameter :: output_ref(*) = [703, 484, 561, 307, 1498, 284, 466, &
523, 13, 198, 198, 1, 40, 892, 326, 262, 749, 1593, 1517, 318]
namelist / input_fastGPT / n_tokens_to_generate

! Load the model
print "(a)", "Loading the model..."
call cpu_time(t1)
open(newunit=u, file="model.dat", form="unformatted", access="stream", status="old")
!read(u) model_version
! fastGPT (digits look similar to the letters they represent)
! model_version /= 0xfa51697
read(u) n_vocab, n_ctx, n_embd, n_layer, n_head, n_decoder_idx, n_decoder_txt, &
n_vocab_idx, n_vocab_txt, n_byte_encoder
allocate(wte(n_embd,n_vocab), wpe(n_embd,n_ctx), &
mlp_fc_w(4*n_embd,n_embd,n_layer), mlp_fc_b(4*n_embd,n_layer), &
mlp_proj_w(n_embd,4*n_embd,n_layer), mlp_proj_b(n_embd,n_layer), &
attn_w(3*n_embd,n_embd,n_layer), attn_b(3*n_embd,n_layer), &
attn_proj_w(n_embd,n_embd,n_layer), attn_proj_b(n_embd,n_layer), &
ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
lnf_b(n_embd), lnf_g(n_embd), &
decoder_idx(n_decoder_idx), decoder_txt(n_decoder_txt), &
vocab_idx(n_vocab_idx), vocab_txt(n_vocab_txt))
if (n_byte_encoder /= 256) error stop "n_byte_encoder must be 256"
read(u) wte, wpe, &
mlp_fc_w, mlp_fc_b, &
mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, &
attn_proj_w, attn_proj_b, &
ln1_b, ln1_g, &
ln2_b, ln2_g, &
lnf_b, lnf_g, &
decoder_idx, decoder_txt, &
vocab_idx, vocab_txt, &
byte_encoder
close(u)
call cpu_time(t2)
print "(a,f8.3,a)", " done. Time:", t2-t1, "s"
print *
print "(a)", "Model parameters:"
print "(a,i6)", "n_vocab =", n_vocab
print "(a,i6)", "n_ctx =", n_ctx
print "(a,i6)", "n_embd =", n_embd
print "(a,i6)", "n_layer =", n_layer
print "(a,i6)", "n_head =", n_head
print *

! Compute byte_decoder:
allocate(byte_decoder(0:maxval(byte_encoder)))
byte_decoder = 0
do i = 0, size(byte_encoder)-1
byte_decoder(byte_encoder(i)) = i
end do

! Load the input
allocate(character(0) :: input_txt)
input_txt = ""
open(newunit=u, file="input", status="old")
read(u, input_fastGPT)
do
read(u, "(a)", iostat=ios) input_txt2
if (ios /= 0) exit
if (len(input_txt) > 0) input_txt = input_txt // char(10)
input_txt = input_txt // trim(input_txt2)
end do
close(u)
print "(a)", "Input text"
print "(a)", input_txt

print *
print "(a)", "Encoding: tokenizing input text into tokens (currently slow)..."
call cpu_time(t1)
input = encode(input_txt, decoder_idx, decoder_txt, vocab_idx, vocab_txt, &
byte_encoder)
call cpu_time(t2)
n_seq = size(input)
print "(a,f8.3,a)", " done. Time:", t2-t1, "s"
print *
print "(a)", "Input parameters:"
print "(a,i4)", "n_seq =", n_seq
print "(a,i4)", "n_tokens_to_generate =", n_tokens_to_generate
print *
print "(a)", "Input tokens:"
print "(1000(i6))", input
print *

if (all(input == input_ref)) then
print *, "Input tokens agree with reference results"
else
print *, "Input tokens DO NOT agree with reference results"
error stop
end if

if (n_seq n_tokens_to_generate >= n_ctx) then
print *, "The maximum sequence length of the model was surpassed."
print *, "Make the input and/or number of tokens to generate shorter."
error stop
end if

print "(a)", "Decoded input as text:"
!print "(a)", decode(input, decoder_idx, decoder_txt, byte_decoder)
allocate(character(0) :: output_txt) ! Fix GFortran warning
output_txt = decode(input, decoder_idx, decoder_txt, byte_decoder)
print "(a)", output_txt
print *

if (input_txt /= output_txt) then
error stop "The decoded input text does not agree with the input text"
end if

allocate(output(n_tokens_to_generate))
print "(a)", "Running model..."
call cpu_time(t1)
t1o = omp_get_wtime()
use_cache = .true.
output = generate(n_tokens_to_generate, n_vocab, n_ctx, size(input), n_embd, &
n_layer, n_head, &
input, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
decoder_idx, decoder_txt, byte_decoder)
t2o = omp_get_wtime()
call cpu_time(t2)
print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
print *
print "(a)", "Output tokens:"
print "(1000(i6))", output
output_txt = decode(output, decoder_idx, decoder_txt, byte_decoder)
print *
print "(a)", "Decoded output as text:"
print "(a)", output_txt

if (all(output == output_ref)) then
print *, "Output tokens agree with reference results"
else
print *, "Output tokens DO NOT agree with reference results"
error stop
end if
end subroutine

end module

0 comments on commit 6bf685f

Please sign in to comment.