Remove hardcoded device="cuda" to support more devices (#2503)
Co-authored-by: Jiang Li <jiang1.li@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
@@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
self.register_parameter(name, weight)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=self.params_dtype))
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
@@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
@@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
|
||||
Reference in New Issue
Block a user