Add w2nn.Print for debug
This commit is contained in:
parent
0fe21eef70
commit
1dc32aaa89
15
lib/Print.lua
Normal file
15
lib/Print.lua
Normal file
|
@ -0,0 +1,15 @@
|
|||
local Print, parent = torch.class('w2nn.Print','nn.Module')
|
||||
|
||||
function Print:__init()
|
||||
parent.__init(self)
|
||||
end
|
||||
function Print:updateOutput(input)
|
||||
print(input:size())
|
||||
self.output:resizeAs(input)
|
||||
self.output:copy(input)
|
||||
return self.output
|
||||
end
|
||||
function Print:updateGradInput(input, gradOutput)
|
||||
self.gradInput:resizeAs(GradOutput)
|
||||
return self.gradInput
|
||||
end
|
16
lib/PrintTable.lua
Normal file
16
lib/PrintTable.lua
Normal file
|
@ -0,0 +1,16 @@
|
|||
local PrintTable, parent = torch.class('w2nn.PrintTable','nn.Module')
|
||||
|
||||
function PrintTable:__init(id)
|
||||
parent.__init(self)
|
||||
self.id = id
|
||||
end
|
||||
function PrintTable:updateOutput(input)
|
||||
print("----", self.id)
|
||||
print(input)
|
||||
self.output = input
|
||||
return self.output
|
||||
end
|
||||
function PrintTable:updateGradInput(input, gradOutput)
|
||||
self.gradInput = gradOutput
|
||||
return self.gradInput
|
||||
end
|
|
@ -75,5 +75,7 @@ else
|
|||
require 'InplaceClip01'
|
||||
require 'L1Criterion'
|
||||
require 'ShakeShakeTable'
|
||||
require 'PrintTable'
|
||||
require 'Print'
|
||||
return w2nn
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue