Remove hardcoded device="cuda" to support more devices (#2503)
Co-authored-by: Jiang Li <jiang1.li@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.weight, {
|
||||
"parallel_dim": 0,
|
||||
@@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"parallel_dim": 0,
|
||||
|
||||
Reference in New Issue
Block a user