Skip to content

Commit

Permalink
Fix to make compatible with MarginRankingCriterion (#108)
Browse files Browse the repository at this point in the history
* Fix to make compatible with MarginRankingCriterion
  • Loading branch information
Abhi Agg authored and soumith committed Apr 28, 2016
1 parent 59deaaa commit c131490
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion ModuleFromCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 24,19 @@ end
function ModuleFromCriterion:updateGradInput(input, gradOutput)
local prediction, target = unpack(input)
local gradPrediction = self.criterion:updateGradInput(prediction, target)
self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
if type(gradPrediction) == 'table' then
if type(self.gradInput[1]) ~= 'table' then
self.gradInput[1] = {} -- initializing to table first time if it is tensor (which it is: line 10)
for i=1, #gradPrediction do
self.gradInput[1][i] = gradPrediction[i].new() -- and putting tensors of right size inside.
end
end
for i=1, #gradPrediction do
self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1])
end
else
self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
end
self.gradInput[2]:resizeAs(target):zero()
return self.gradInput
end

0 comments on commit c131490

Please sign in to comment.