Skip to content

Commit

Permalink
Merge pull request certik#39 from certik/tests4
Browse files Browse the repository at this point in the history
Add a test for another input text
  • Loading branch information
certik authored Mar 21, 2023
2 parents 0bfbb35 45328f8 commit 5ff9e51
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ endif()

# Main executable: built from main.f90 and linked against the fastgpt target.
add_executable(gpt2 main.f90)
target_link_libraries(gpt2 fastgpt)

# Test executable built from tests/test_basic_input.f90, linked against
# fastgpt and registered with add_test; the test runs the binary from the
# project binary directory.
add_executable(test_basic_input tests/test_basic_input.f90)
target_link_libraries(test_basic_input fastgpt)
add_test(test_basic_input ${PROJECT_BINARY_DIR}/test_basic_input)

# Test executable built from tests/test_more_inputs.f90, set up the same way
# as test_basic_input above.
add_executable(test_more_inputs tests/test_more_inputs.f90)
target_link_libraries(test_more_inputs fastgpt)
add_test(test_more_inputs ${PROJECT_BINARY_DIR}/test_more_inputs)

if(NOT PROJECT_SOURCE_DIR STREQUAL PROJECT_BINARY_DIR)
# Git auto-ignore out-of-source build directory
file(GENERATE OUTPUT .gitignore CONTENT "*")
Expand Down
51 changes: 34 additions & 17 deletions driver.f90
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,41 @@ module driver

contains

subroutine load_input(filename, input_txt, n_tokens_to_generate)
! Load the generation settings and the prompt text from the file `filename`.
!
! The file must begin with the namelist group `input_fastGPT`, which sets
! `n_tokens_to_generate`. Every record after the namelist is read as one line
! of prompt text; the lines are joined with a newline (char(10)).
!
! Arguments:
!   filename              path of an existing input file
!   input_txt             the assembled prompt text (zero length if no text
!                         follows the namelist)
!   n_tokens_to_generate  value read from the namelist group
!
! Notes:
!   * Trailing blanks of every line are removed (each line goes through trim).
!   * Blank lines that precede the first non-blank line are dropped, because
!     the newline separator is only inserted once `input_txt` is non-empty.
!   * Lines longer than `max_line_len` characters are truncated by the read.
!
! Stops with an error message if the file cannot be opened or the namelist
! group cannot be read.
use, intrinsic :: iso_fortran_env, only: error_unit
character(*), intent(in) :: filename
character(:), allocatable, intent(out) :: input_txt
integer, intent(out) :: n_tokens_to_generate
integer, parameter :: max_line_len = 1024
character(max_line_len) :: line
character(256) :: msg
integer :: u, ios
namelist / input_fastGPT / n_tokens_to_generate

! Assignment allocates the deferred-length string to length zero
input_txt = ""

open(newunit=u, file=filename, status="old", action="read", &
    iostat=ios, iomsg=msg)
if (ios /= 0) then
    write(error_unit, "(5a)") "load_input: cannot open '", filename, &
        "': ", trim(msg), ""
    error stop "load_input: cannot open the input file"
end if

read(u, input_fastGPT, iostat=ios, iomsg=msg)
if (ios /= 0) then
    close(u)
    write(error_unit, "(5a)") "load_input: cannot read namelist " // &
        "input_fastGPT from '", filename, "': ", trim(msg), ""
    error stop "load_input: cannot read the input_fastGPT namelist"
end if

do
    read(u, "(a)", iostat=ios) line
    ! Any non-zero status (end of file or a read error) ends the text
    if (ios /= 0) exit
    if (len(input_txt) > 0) input_txt = input_txt // char(10)
    input_txt = input_txt // trim(line)
end do
close(u)
end subroutine load_input

subroutine gpt2_driver(input, output)
! File-based entry point: reads the prompt text and the number of tokens to
! generate from the file named "input" (via load_input) and forwards them,
! together with `input` and `output`, to gpt2_driver2.
integer, allocatable, intent(out) :: input(:), output(:)
integer :: n_generate
character(:), allocatable :: prompt
call load_input("input", prompt, n_generate)
call gpt2_driver2(prompt, n_generate, input, output)
end subroutine gpt2_driver

subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
character(*), intent(in) :: input_txt
integer, intent(in) :: n_tokens_to_generate
integer, allocatable, intent(out) :: input(:), output(:)
integer :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_tokens_to_generate, n_decoder_idx, n_decoder_txt, &
n_decoder_idx, n_decoder_txt, &
n_vocab_idx, n_vocab_txt, n_byte_encoder
integer, allocatable :: decoder_idx(:), vocab_idx(:), byte_decoder(:)
integer :: byte_encoder(0:255)
Expand All @@ -25,12 +56,10 @@ subroutine gpt2_driver(input, output)
ln2_b(:,:), ln2_g(:,:), &
lnf_b(:), lnf_g(:)
character, allocatable :: decoder_txt(:), vocab_txt(:)
character(:), allocatable :: output_txt, input_txt
character(1024) :: input_txt2
character(:), allocatable :: output_txt
real(dp) :: t1, t2, t1o, t2o
integer :: u, i, ios
integer :: u, i
logical :: use_cache
namelist / input_fastGPT / n_tokens_to_generate

! Load the model
print "(a)", "Loading the model..."
Expand Down Expand Up @@ -82,18 +111,6 @@ subroutine gpt2_driver(input, output)
byte_decoder(byte_encoder(i)) = i
end do

! Load the input
allocate(character(0) :: input_txt)
input_txt = ""
open(newunit=u, file="input", status="old")
read(u, input_fastGPT)
do
read(u, "(a)", iostat=ios) input_txt2
if (ios /= 0) exit
if (len(input_txt) > 0) input_txt = input_txt // char(10)
input_txt = input_txt // trim(input_txt2)
end do
close(u)
print "(a)", "Input text"
print "(a)", input_txt

Expand Down
31 changes: 31 additions & 0 deletions tests/test_more_inputs.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
program test_more_inputs
! Regression test for gpt2_driver2 with a prompt that contains non-ASCII
! characters. The input and output token arrays returned by the driver are
! compared against stored reference values; any mismatch (including a
! different number of tokens) terminates the program with `error stop`.
use driver, only: gpt2_driver2
implicit none

integer, parameter :: input_ref(*) = [46, 358, 129, 247, 68, 73, 34754, 234, &
    861, 8836, 74, 373, 4642, 287]
integer, parameter :: output_ref(*) = [1248, 5332, 287, 262, 7404, 286, &
    25370, 254, 368, 83, 6557, 81, 11]
! Number of tokens requested from the driver
integer, parameter :: n_generate = 13
integer, allocatable :: input(:), output(:)

call gpt2_driver2("Ondřej Čertík was born in ", n_generate, input, output)

print *
print *, "TESTS:"

if (tokens_match(input, input_ref)) then
    print *, "Input tokens agree with reference results"
else
    print *, "Input tokens DO NOT agree with reference results"
    print *, "expected:", input_ref
    print *, "got:     ", input
    error stop "test_more_inputs: input tokens differ from reference"
end if

if (tokens_match(output, output_ref)) then
    print *, "Output tokens agree with reference results"
else
    print *, "Output tokens DO NOT agree with reference results"
    print *, "expected:", output_ref
    print *, "got:     ", output
    error stop "test_more_inputs: output tokens differ from reference"
end if

contains

    pure logical function tokens_match(actual, expected) result(ok)
    ! True when both arrays have the same size and identical elements.
    ! The size is checked first because an elementwise comparison of arrays
    ! with different shapes is not conforming Fortran.
    integer, intent(in) :: actual(:), expected(:)
    ok = size(actual) == size(expected)
    if (ok) ok = all(actual == expected)
    end function tokens_match

end program test_more_inputs

0 comments on commit 5ff9e51

Please sign in to comment.